Skip to content

Commit ed0fc15

Browse files
committed
port avx512 implementation of compare256
1 parent bc0c9fd commit ed0fc15

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

zlib-rs/src/deflate/compare256.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize {
99
}
1010

1111
fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
12+
if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
13+
return unsafe { avx512::compare256(src0, src1) };
14+
}
15+
1216
#[cfg(target_arch = "x86_64")]
1317
if crate::cpu_features::is_enabled_avx2_and_bmi2() {
1418
return unsafe { avx2::compare256(src0, src1) };
@@ -263,6 +267,92 @@ mod avx2 {
263267
}
264268
}
265269

270+
#[cfg(target_arch = "x86_64")]
271+
mod avx512 {
272+
use core::arch::x86_64::{
273+
_mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128,
274+
};
275+
276+
/// # Safety
277+
///
278+
/// Behavior is undefined if the `avx` target feature is not enabled
279+
#[target_feature(enable = "avx512vl")]
280+
#[target_feature(enable = "avx512bw")]
281+
pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
282+
// First do a 16byte round before increasing to 64bytes, this reduces the
283+
// penalty for the short matches, and those are usually the most common ones.
284+
// This requires us to overlap on the last round, giving a small penalty
285+
// on matches of 192+ bytes (Still faster than AVX2 though).
286+
287+
unsafe {
288+
// 16 bytes
289+
let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast());
290+
let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast());
291+
let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0)); // zero-extended to use __builtin_ctz
292+
if mask_0 != 0x0000FFFF {
293+
// There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here
294+
let match_byte = mask_0.trailing_ones();
295+
return match_byte as usize;
296+
}
297+
298+
// 64 bytes
299+
let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast());
300+
let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast());
301+
let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1);
302+
if mask_1 != 0xFFFFFFFFFFFFFFFF {
303+
let match_byte = mask_1.trailing_ones();
304+
return 16 + match_byte as usize;
305+
}
306+
307+
// 64 bytes
308+
let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast());
309+
let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast());
310+
let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2);
311+
if mask_2 != 0xFFFFFFFFFFFFFFFF {
312+
let match_byte = mask_2.trailing_ones();
313+
return 80 + match_byte as usize;
314+
}
315+
316+
// 64 bytes
317+
let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast());
318+
let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast());
319+
let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3);
320+
if mask_3 != 0xFFFFFFFFFFFFFFFF {
321+
let match_byte = mask_3.trailing_ones();
322+
return 144 + match_byte as usize;
323+
}
324+
325+
// 64 bytes (overlaps the previous 16 bytes for fast tail processing)
326+
let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast());
327+
let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast());
328+
let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4);
329+
if mask_4 != 0xFFFFFFFFFFFFFFFF {
330+
let match_byte = mask_4.trailing_ones();
331+
return 192 + match_byte as usize;
332+
}
333+
}
334+
335+
256
336+
}
337+
338+
#[test]
339+
fn test_compare256() {
340+
if true {
341+
let str1 = [b'a'; super::MAX_COMPARE_SIZE];
342+
let mut str2 = [b'a'; super::MAX_COMPARE_SIZE];
343+
344+
for i in 0..str1.len() {
345+
str2[i] = 0;
346+
347+
let match_len = unsafe { compare256(&str1, &str2) };
348+
assert_eq!(match_len, i);
349+
350+
str2[i] = b'a';
351+
}
352+
}
353+
}
354+
}
355+
266356
#[cfg(target_arch = "wasm32")]
267357
mod wasm32 {
268358
use core::arch::wasm32::{u8x16_bitmask, u8x16_eq, v128, v128_load};

0 commit comments

Comments
 (0)