Skip to content

Commit fde4111

Browse files
committed
port avx512 implementation of compare256
1 parent c7e9545 commit fde4111

File tree

2 files changed

+124
-6
lines changed

2 files changed

+124
-6
lines changed

.github/workflows/checks.yaml

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,34 @@ jobs:
488488
run: "cargo +nightly miri nextest run -j4 -p test-libz-rs-sys --target ${{ matrix.target }} null::"
489489
env:
490490
RUSTFLAGS: "-Ctarget-feature=+avx2,+bmi2,+bmi1"
491+
- name: Test allocator with miri
492+
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
493+
- name: Test gz logic with miri
494+
working-directory: libz-rs-sys-cdylib
495+
run: "cargo +nightly miri nextest run -j4 -p libz-rs-sys-cdylib --target ${{ matrix.target }} --features=gz"
496+
env:
497+
MIRIFLAGS: "-Zmiri-tree-borrows -Zmiri-disable-isolation"
498+
499+
miri-avx512:
500+
name: "Miri avx512"
501+
runs-on: ubuntu-latest
502+
strategy:
503+
matrix:
504+
include:
505+
- target: x86_64-unknown-linux-gnu
506+
env:
507+
QUICKCHECK_TESTS: 10
508+
steps:
509+
- uses: actions/checkout@v3
510+
- name: Install Miri
511+
run: |
512+
rustup target add ${{ matrix.target }}
513+
rustup toolchain install nightly --component miri
514+
cargo +nightly miri setup
515+
- name: Install cargo-nextest
516+
uses: taiki-e/install-action@d12e869b89167df346dd0ff65da342d1fb1202fb # v2.53.2
517+
with:
518+
tool: cargo-nextest
491519
- name: Test avx512 crc32 implementation
492520
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=vpclmulqdq crc32::"
493521
env:
@@ -496,13 +524,10 @@ jobs:
496524
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=avx512 adler32::"
497525
env:
498526
RUSTFLAGS: "-Ctarget-feature=+avx2,+bmi2,+bmi1,+avx512f,+avx512bw"
499-
- name: Test allocator with miri
500-
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} allocate::"
501-
- name: Test gz logic with miri
502-
working-directory: libz-rs-sys-cdylib
503-
run: "cargo +nightly miri nextest run -j4 -p libz-rs-sys-cdylib --target ${{ matrix.target }} --features=gz"
527+
- name: Test avx512 compare256 implementation
528+
run: "cargo +nightly miri nextest run -j4 -p zlib-rs --target ${{ matrix.target }} --features=avx512 compare256"
504529
env:
505-
MIRIFLAGS: "-Zmiri-tree-borrows -Zmiri-disable-isolation"
530+
RUSTFLAGS: "-Ctarget-feature=+avx512vl,+avx512bw"
506531

507532
run-flate2-test-suite:
508533
name: run flate2 test suite

zlib-rs/src/deflate/compare256.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize {
99
}
1010

1111
fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
12+
#[cfg(feature = "avx512")]
13+
#[cfg(target_arch = "x86_64")]
14+
if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
15+
return unsafe { avx512::compare256(src0, src1) };
16+
}
17+
1218
#[cfg(target_arch = "x86_64")]
1319
if crate::cpu_features::is_enabled_avx2_and_bmi2() {
1420
return unsafe { avx2::compare256(src0, src1) };
@@ -263,6 +269,93 @@ mod avx2 {
263269
}
264270
}
265271

272+
#[cfg(feature = "avx512")]
273+
#[cfg(target_arch = "x86_64")]
274+
mod avx512 {
275+
use core::arch::x86_64::{
276+
_mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128,
277+
};
278+
279+
/// # Safety
280+
///
281+
/// Behavior is undefined if the `avx` target feature is not enabled
282+
#[target_feature(enable = "avx512vl")]
283+
#[target_feature(enable = "avx512bw")]
284+
pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
285+
// First do a 16byte round before increasing to 64bytes, this reduces the
286+
// penalty for the short matches, and those are usually the most common ones.
287+
// This requires us to overlap on the last round, giving a small penalty
288+
// on matches of 192+ bytes (Still faster than AVX2 though).
289+
290+
unsafe {
291+
// 16 bytes
292+
let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast());
293+
let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast());
294+
let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0)); // zero-extended to use __builtin_ctz
295+
if mask_0 != 0x0000FFFF {
296+
// There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here
297+
let match_byte = mask_0.trailing_ones();
298+
return match_byte as usize;
299+
}
300+
301+
// 64 bytes
302+
let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast());
303+
let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast());
304+
let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1);
305+
if mask_1 != 0xFFFFFFFFFFFFFFFF {
306+
let match_byte = mask_1.trailing_ones();
307+
return 16 + match_byte as usize;
308+
}
309+
310+
// 64 bytes
311+
let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast());
312+
let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast());
313+
let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2);
314+
if mask_2 != 0xFFFFFFFFFFFFFFFF {
315+
let match_byte = mask_2.trailing_ones();
316+
return 80 + match_byte as usize;
317+
}
318+
319+
// 64 bytes
320+
let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast());
321+
let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast());
322+
let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3);
323+
if mask_3 != 0xFFFFFFFFFFFFFFFF {
324+
let match_byte = mask_3.trailing_ones();
325+
return 144 + match_byte as usize;
326+
}
327+
328+
// 64 bytes (overlaps the previous 16 bytes for fast tail processing)
329+
let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast());
330+
let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast());
331+
let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4);
332+
if mask_4 != 0xFFFFFFFFFFFFFFFF {
333+
let match_byte = mask_4.trailing_ones();
334+
return 192 + match_byte as usize;
335+
}
336+
}
337+
338+
256
339+
}
340+
341+
#[test]
342+
fn test_compare256() {
343+
if true {
344+
let str1 = [b'a'; super::MAX_COMPARE_SIZE];
345+
let mut str2 = [b'a'; super::MAX_COMPARE_SIZE];
346+
347+
for i in 0..str1.len() {
348+
str2[i] = 0;
349+
350+
let match_len = unsafe { compare256(&str1, &str2) };
351+
assert_eq!(match_len, i);
352+
353+
str2[i] = b'a';
354+
}
355+
}
356+
}
357+
}
358+
266359
#[cfg(target_arch = "wasm32")]
267360
mod wasm32 {
268361
use core::arch::wasm32::{u8x16_bitmask, u8x16_eq, v128, v128_load};

0 commit comments

Comments
 (0)