@@ -9,6 +9,12 @@ pub fn compare256_slice(src0: &[u8], src1: &[u8]) -> usize {
99}
1010
1111fn compare256 ( src0 : & [ u8 ; 256 ] , src1 : & [ u8 ; 256 ] ) -> usize {
12+ #[ cfg( feature = "avx512" ) ]
13+ #[ cfg( target_arch = "x86_64" ) ]
14+ if cfg ! ( target_feature = "avx512vl" ) && cfg ! ( target_feature = "avx512bw" ) {
15+ return unsafe { avx512:: compare256 ( src0, src1) } ;
16+ }
17+
1218 #[ cfg( target_arch = "x86_64" ) ]
1319 if crate :: cpu_features:: is_enabled_avx2_and_bmi2 ( ) {
1420 return unsafe { avx2:: compare256 ( src0, src1) } ;
@@ -263,6 +269,93 @@ mod avx2 {
263269 }
264270}
265271
#[cfg(feature = "avx512")]
#[cfg(target_arch = "x86_64")]
mod avx512 {
    use core::arch::x86_64::{
        _mm512_cmpeq_epu8_mask, _mm512_loadu_si512, _mm_cmpeq_epu8_mask, _mm_loadu_si128,
    };

    /// Returns the length of the common prefix (in bytes, `0..=256`) of two
    /// 256-byte buffers, using AVX-512 byte-equality mask instructions.
    ///
    /// # Safety
    ///
    /// Behavior is undefined if the `avx512vl` and `avx512bw` target features
    /// are not both available on the executing CPU.
    #[target_feature(enable = "avx512vl")]
    #[target_feature(enable = "avx512bw")]
    pub unsafe fn compare256(src0: &[u8; 256], src1: &[u8; 256]) -> usize {
        // First do a 16-byte round before increasing to 64 bytes; this reduces the
        // penalty for the short matches, and those are usually the most common ones.
        // This requires us to overlap on the last round, giving a small penalty
        // on matches of 192+ bytes (still faster than AVX2 though).

        unsafe {
            // 16 bytes: src[0..16]
            let xmm_src0_0 = _mm_loadu_si128(src0.as_ptr().cast());
            let xmm_src1_0 = _mm_loadu_si128(src1.as_ptr().cast());
            // Zero-extended to u32 so `trailing_ones` lowers to a plain ctz/tzcnt.
            let mask_0 = u32::from(_mm_cmpeq_epu8_mask(xmm_src0_0, xmm_src1_0));
            if mask_0 != 0x0000FFFF {
                // There is potential for using __builtin_ctzg/__builtin_ctzs/_tzcnt_u16/__tzcnt_u16 here
                let match_byte = mask_0.trailing_ones();
                return match_byte as usize;
            }

            // 64 bytes: src[16..80]
            let zmm_src0_1 = _mm512_loadu_si512(src0[16..].as_ptr().cast());
            let zmm_src1_1 = _mm512_loadu_si512(src1[16..].as_ptr().cast());
            let mask_1 = _mm512_cmpeq_epu8_mask(zmm_src0_1, zmm_src1_1);
            if mask_1 != 0xFFFFFFFFFFFFFFFF {
                let match_byte = mask_1.trailing_ones();
                return 16 + match_byte as usize;
            }

            // 64 bytes: src[80..144]
            let zmm_src0_2 = _mm512_loadu_si512(src0[80..].as_ptr().cast());
            let zmm_src1_2 = _mm512_loadu_si512(src1[80..].as_ptr().cast());
            let mask_2 = _mm512_cmpeq_epu8_mask(zmm_src0_2, zmm_src1_2);
            if mask_2 != 0xFFFFFFFFFFFFFFFF {
                let match_byte = mask_2.trailing_ones();
                return 80 + match_byte as usize;
            }

            // 64 bytes: src[144..208]
            let zmm_src0_3 = _mm512_loadu_si512(src0[144..].as_ptr().cast());
            let zmm_src1_3 = _mm512_loadu_si512(src1[144..].as_ptr().cast());
            let mask_3 = _mm512_cmpeq_epu8_mask(zmm_src0_3, zmm_src1_3);
            if mask_3 != 0xFFFFFFFFFFFFFFFF {
                let match_byte = mask_3.trailing_ones();
                return 144 + match_byte as usize;
            }

            // 64 bytes: src[192..256] (overlaps the previous round's last 16
            // bytes for fast tail processing; any mismatch in the overlap was
            // already reported above, so `192 + match_byte` stays correct).
            let zmm_src0_4 = _mm512_loadu_si512(src0[192..].as_ptr().cast());
            let zmm_src1_4 = _mm512_loadu_si512(src1[192..].as_ptr().cast());
            let mask_4 = _mm512_cmpeq_epu8_mask(zmm_src0_4, zmm_src1_4);
            if mask_4 != 0xFFFFFFFFFFFFFFFF {
                let match_byte = mask_4.trailing_ones();
                return 192 + match_byte as usize;
            }
        }

        // All 256 bytes are equal.
        256
    }

    #[test]
    fn test_compare256() {
        // Only run when the required AVX-512 target features are enabled at
        // compile time, mirroring the dispatch check in `super::compare256`.
        // Calling `compare256` without them is undefined behavior (it would
        // execute AVX-512 instructions the CPU may not support).
        if cfg!(target_feature = "avx512vl") && cfg!(target_feature = "avx512bw") {
            let str1 = [b'a'; super::MAX_COMPARE_SIZE];
            let mut str2 = [b'a'; super::MAX_COMPARE_SIZE];

            for i in 0..str1.len() {
                str2[i] = 0;

                // SAFETY: required target features are verified above.
                let match_len = unsafe { compare256(&str1, &str2) };
                assert_eq!(match_len, i);

                str2[i] = b'a';
            }
        }
    }
}
358+
266359#[ cfg( target_arch = "wasm32" ) ]
267360mod wasm32 {
268361 use core:: arch:: wasm32:: { u8x16_bitmask, u8x16_eq, v128, v128_load} ;
0 commit comments