Skip to content

Commit c120876

Browse files
committed
port avx512 implementation of compare256
1 parent bc0c9fd commit c120876

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

.github/workflows/checks.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,10 @@ jobs:
492492
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=vpclmulqdq crc32::"
493493
env:
494494
RUSTFLAGS: "-Ctarget-feature=+vpclmulqdq,+avx512f"
495+
- name: Test avx512 compare256 implementation
496+
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=avx512 compare256"
497+
env:
498+
RUSTFLAGS: "-Ctarget-feature=+avx512vl,+avx512bw"
495499
- name: Test allocator with miri
496500
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
497501
- name: Test gz logic with miri

zlib-rs/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ __internal-fuzz = ["arbitrary"]
2424
__internal-fuzz-disable-checksum = [] # disable checksum validation on inflate
2525
__internal-test = ["quickcheck"]
2626
ZLIB_DEBUG = []
27-
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards
27+
vpclmulqdq = [] # use avx512 to speed up crc32. Only stable from 1.89.0 onwards.
28+
avx512 = ["vpclmulqdq"] # use avx512 to speed up crc32 and adler32. Only stable from 1.89.0 onwards.
2829

2930

3031
[dependencies]

zlib-rs/src/deflate/compare256.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize {
99
}
1010

1111
fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
12+
#[cfg(feature = "avx512")]
13+
#[cfg(target_arch = "x86_64")]
14+
if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
15+
return unsafe { avx512::compare256(src0, src1) };
16+
}
17+
1218
#[cfg(target_arch = "x86_64")]
1319
if crate::cpu_features::is_enabled_avx2_and_bmi2() {
1420
return unsafe { avx2::compare256(src0, src1) };
@@ -263,6 +269,93 @@ mod avx2 {
263269
}
264270
}
265271

272+
#[cfg(feature = "avx512")]
273+
#[cfg(target_arch = "x86_64")]
274+
mod avx512 {
275+
use core::arch::x86_64::{
276+
_mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128,
277+
};
278+
279+
/// # Safety
280+
///
281+
/// Behavior is undefined if the `avx` target feature is not enabled
282+
#[target_feature(enable = "avx512vl")]
283+
#[target_feature(enable = "avx512bw")]
284+
pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
285+
// First do a 16byte round before increasing to 64bytes, this reduces the
286+
// penalty for the short matches, and those are usually the most common ones.
287+
// This requires us to overlap on the last round, giving a small penalty
288+
// on matches of 192+ bytes (Still faster than AVX2 though).
289+
290+
unsafe {
291+
// 16 bytes
292+
let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast());
293+
let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast());
294+
let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0)); // zero-extended to use __builtin_ctz
295+
if mask_0 != 0x0000FFFF {
296+
// There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here
297+
let match_byte = mask_0.trailing_ones();
298+
return match_byte as usize;
299+
}
300+
301+
// 64 bytes
302+
let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast());
303+
let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast());
304+
let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1);
305+
if mask_1 != 0xFFFFFFFFFFFFFFFF {
306+
let match_byte = mask_1.trailing_ones();
307+
return 16 + match_byte as usize;
308+
}
309+
310+
// 64 bytes
311+
let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast());
312+
let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast());
313+
let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2);
314+
if mask_2 != 0xFFFFFFFFFFFFFFFF {
315+
let match_byte = mask_2.trailing_ones();
316+
return 80 + match_byte as usize;
317+
}
318+
319+
// 64 bytes
320+
let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast());
321+
let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast());
322+
let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3);
323+
if mask_3 != 0xFFFFFFFFFFFFFFFF {
324+
let match_byte = mask_3.trailing_ones();
325+
return 144 + match_byte as usize;
326+
}
327+
328+
// 64 bytes (overlaps the previous 16 bytes for fast tail processing)
329+
let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast());
330+
let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast());
331+
let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4);
332+
if mask_4 != 0xFFFFFFFFFFFFFFFF {
333+
let match_byte = mask_4.trailing_ones();
334+
return 192 + match_byte as usize;
335+
}
336+
}
337+
338+
256
339+
}
340+
341+
#[test]
342+
fn test_compare256() {
343+
if true {
344+
let str1 = [b'a'; super::MAX_COMPARE_SIZE];
345+
let mut str2 = [b'a'; super::MAX_COMPARE_SIZE];
346+
347+
for i in 0..str1.len() {
348+
str2[i] = 0;
349+
350+
let match_len = unsafe { compare256(&str1, &str2) };
351+
assert_eq!(match_len, i);
352+
353+
str2[i] = b'a';
354+
}
355+
}
356+
}
357+
}
358+
266359
#[cfg(target_arch = "wasm32")]
267360
mod wasm32 {
268361
use core::arch::wasm32::{u8x16_bitmask, u8x16_eq, v128, v128_load};

0 commit comments

Comments
 (0)