@@ -9,6 +9,14 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize {
 }
 
 fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
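+    // Note: `cfg!(target_feature = ...)` is a compile-time check, so this path is
+    // only taken when the crate is built with AVX-512 enabled (for example via
+    // `-C target-feature=+avx512vl,+avx512bw` or `-C target-cpu=native`).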
+    #[cfg(target_arch = "x86_64")]
+    if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
+        return unsafe { avx512::compare256(src0, src1) };
+    }
+
     #[cfg(target_arch = "x86_64")]
     if crate::cpu_features::is_enabled_avx2_and_bmi2() {
         return unsafe { avx2::compare256(src0, src1) };
@@ -263,6 +271,102 @@ mod avx2 {
     }
 }
 
+#[cfg(target_arch = "x86_64")]
+mod avx512 {
+    use core::arch::x86_64::{
+        _mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128,
+    };
+
+    /// # Safety
+    ///
+    /// Behavior is undefined if the `avx512vl` and `avx512bw` target features are not enabled
+    #[target_feature(enable = "avx512vl")]
+    #[target_feature(enable = "avx512bw")]
+    pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
+        // First do a 16-byte round before increasing to 64-byte rounds. This reduces
+        // the penalty for short matches, which are usually the most common ones. It
+        // requires overlapping on the last round, giving a small penalty on matches
+        // of 192+ bytes (still faster than AVX2, though).
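+        //
+        // The rounds cover bytes 0..16, 16..80, 80..144, 144..208, and 192..256:
+        // the final 64-byte load starts at offset 192 so that it ends exactly at
+        // byte 256, re-checking the bytes 192..208 from the previous round.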
+
+        unsafe {
+            // 16 bytes
+            let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast());
+            let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast());
+            let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0)); // zero-extended to u32, mirroring the C version's use of __builtin_ctz
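+            // Each mask bit is set where the corresponding bytes are equal, so a
+            // full match of all 16 bytes yields 0x0000FFFF and `trailing_ones`
+            // counts how many bytes matched before the first mismatch.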
+            if mask_0 != 0x0000FFFF {
+                // There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here
+                let match_byte = mask_0.trailing_ones();
+                return match_byte as usize;
+            }
+
+            // 64 bytes
+            let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast());
+            let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast());
+            let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1);
+            if mask_1 != 0xFFFFFFFFFFFFFFFF {
+                let match_byte = mask_1.trailing_ones();
+                return 16 + match_byte as usize;
+            }
+
+            // 64 bytes
+            let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast());
+            let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast());
+            let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2);
+            if mask_2 != 0xFFFFFFFFFFFFFFFF {
+                let match_byte = mask_2.trailing_ones();
+                return 80 + match_byte as usize;
+            }
+
+            // 64 bytes
+            let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast());
+            let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast());
+            let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3);
+            if mask_3 != 0xFFFFFFFFFFFFFFFF {
+                let match_byte = mask_3.trailing_ones();
+                return 144 + match_byte as usize;
+            }
+
+            // 64 bytes (overlaps the previous round by 16 bytes for fast tail processing)
+            let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast());
+            let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast());
+            let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4);
+            if mask_4 != 0xFFFFFFFFFFFFFFFF {
+                let match_byte = mask_4.trailing_ones();
+                return 192 + match_byte as usize;
+            }
+        }
+
+        256 // all 256 bytes are equal
+    }
+
+    #[test]
+    fn test_compare256() {
+        // Only run when compiled with AVX-512 enabled: calling `compare256`
+        // without it would be undefined behavior. This mirrors the compile-time
+        // dispatch check in the outer `compare256`.
+        if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
+            let str1 = [b'a'; super::MAX_COMPARE_SIZE];
+            let mut str2 = [b'a'; super::MAX_COMPARE_SIZE];
+
+            for i in 0..str1.len() {
+                str2[i] = 0;
+
+                let match_len = unsafe { compare256(&str1, &str2) };
+                assert_eq!(match_len, i);
+
+                str2[i] = b'a';
+            }
+        }
+    }
+}
+
 #[cfg(target_arch = "wasm32")]
 mod wasm32 {
     use core::arch::wasm32::{u8x16_bitmask, u8x16_eq, v128, v128_load};