Skip to content

Commit dbab616

Browse files
committed
Move multiversioned functions outside of Searcher trait
1 parent b18a60d commit dbab616

File tree

5 files changed

+142
-110
lines changed

5 files changed

+142
-110
lines changed

src/aarch64.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,23 +227,25 @@ impl<N: Needle> NeonSearcher<N> {
227227
#[inline]
228228
unsafe fn neon_2_search_in(&self, haystack: &[u8], end: usize) -> bool {
229229
let hash = VectorHash::<uint8x2_t>::from(&self.neon_half_hash);
230-
self.vector_search_in_neon_version(haystack, end, &hash)
230+
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, &hash)
231231
}
232232

233233
#[inline]
234234
unsafe fn neon_4_search_in(&self, haystack: &[u8], end: usize) -> bool {
235235
let hash = VectorHash::<uint8x4_t>::from(&self.neon_half_hash);
236-
self.vector_search_in_neon_version(haystack, end, &hash)
236+
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, &hash)
237237
}
238238

239239
#[inline]
240240
unsafe fn neon_8_search_in(&self, haystack: &[u8], end: usize) -> bool {
241-
self.vector_search_in_neon_version(haystack, end, &self.neon_half_hash)
241+
let hash = &self.neon_half_hash;
242+
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, hash)
242243
}
243244

244245
#[inline]
245246
unsafe fn neon_search_in(&self, haystack: &[u8], end: usize) -> bool {
246-
self.vector_search_in_neon_version(haystack, end, &self.neon_hash)
247+
let hash = &self.neon_hash;
248+
crate::vector_search_in_neon_version(self.needle(), self.position(), haystack, end, hash)
247249
}
248250

249251
/// Inlined version of `search_in` for hot call sites.

src/lib.rs

Lines changed: 91 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,6 @@ trait NeedleWithSize: Needle {
112112
self.as_bytes().len()
113113
}
114114
}
115-
116-
#[inline]
117-
fn is_empty(&self) -> bool {
118-
self.size() == 0
119-
}
120115
}
121116

122117
impl<N: Needle + ?Sized> NeedleWithSize for N {}
@@ -192,96 +187,106 @@ impl<T: Vector, V: Vector + From<T>> From<&VectorHash<T>> for VectorHash<V> {
192187
}
193188
}
194189

195-
trait Searcher<N: NeedleWithSize + ?Sized> {
196-
fn needle(&self) -> &N;
190+
#[multiversion::multiversion]
191+
#[clone(target = "[x86|x86_64]+avx2")]
192+
#[clone(target = "wasm32+simd128")]
193+
#[clone(target = "aarch64+neon")]
194+
unsafe fn vector_search_in_chunk<N: NeedleWithSize + ?Sized, V: Vector>(
195+
needle: &N,
196+
position: usize,
197+
hash: &VectorHash<V>,
198+
start: *const u8,
199+
mask: u32,
200+
) -> bool {
201+
let first = V::load(start);
202+
let last = V::load(start.add(position));
203+
204+
let eq_first = V::lanes_eq(hash.first, first);
205+
let eq_last = V::lanes_eq(hash.last, last);
206+
207+
let eq = V::bitwise_and(eq_first, eq_last);
208+
let mut eq = V::to_bitmask(eq) & mask;
209+
210+
let chunk = start.add(1);
211+
let size = needle.size() - 1;
212+
let needle = needle.as_bytes().as_ptr().add(1);
213+
214+
while eq != 0 {
215+
let chunk = chunk.add(eq.trailing_zeros() as usize);
216+
let equal = match N::SIZE {
217+
Some(0) => unreachable!(),
218+
Some(1) => dispatch!(memcmp::specialized::<0>(chunk, needle)),
219+
Some(2) => dispatch!(memcmp::specialized::<1>(chunk, needle)),
220+
Some(3) => dispatch!(memcmp::specialized::<2>(chunk, needle)),
221+
Some(4) => dispatch!(memcmp::specialized::<3>(chunk, needle)),
222+
Some(5) => dispatch!(memcmp::specialized::<4>(chunk, needle)),
223+
Some(6) => dispatch!(memcmp::specialized::<5>(chunk, needle)),
224+
Some(7) => dispatch!(memcmp::specialized::<6>(chunk, needle)),
225+
Some(8) => dispatch!(memcmp::specialized::<7>(chunk, needle)),
226+
Some(9) => dispatch!(memcmp::specialized::<8>(chunk, needle)),
227+
Some(10) => dispatch!(memcmp::specialized::<9>(chunk, needle)),
228+
Some(11) => dispatch!(memcmp::specialized::<10>(chunk, needle)),
229+
Some(12) => dispatch!(memcmp::specialized::<11>(chunk, needle)),
230+
Some(13) => dispatch!(memcmp::specialized::<12>(chunk, needle)),
231+
Some(14) => dispatch!(memcmp::specialized::<13>(chunk, needle)),
232+
Some(15) => dispatch!(memcmp::specialized::<14>(chunk, needle)),
233+
Some(16) => dispatch!(memcmp::specialized::<15>(chunk, needle)),
234+
_ => dispatch!(memcmp::generic(chunk, needle, size)),
235+
};
236+
if equal {
237+
return true;
238+
}
197239

198-
fn position(&self) -> usize;
240+
eq = dispatch!(bits::clear_leftmost_set(eq));
241+
}
199242

200-
#[multiversion::multiversion]
201-
#[clone(target = "[x86|x86_64]+avx2")]
202-
#[clone(target = "wasm32+simd128")]
203-
#[clone(target = "aarch64+neon")]
204-
unsafe fn vector_search_in_chunk<V: Vector>(
205-
&self,
206-
hash: &VectorHash<V>,
207-
start: *const u8,
208-
mask: u32,
209-
) -> bool {
210-
let first = V::load(start);
211-
let last = V::load(start.add(self.position()));
212-
213-
let eq_first = V::lanes_eq(hash.first, first);
214-
let eq_last = V::lanes_eq(hash.last, last);
215-
216-
let eq = V::bitwise_and(eq_first, eq_last);
217-
let mut eq = V::to_bitmask(eq) & mask;
218-
219-
let chunk = start.add(1);
220-
let needle = self.needle().as_bytes().as_ptr().add(1);
221-
222-
while eq != 0 {
223-
let chunk = chunk.add(eq.trailing_zeros() as usize);
224-
let equal = match N::SIZE {
225-
Some(0) => unreachable!(),
226-
Some(1) => dispatch!(memcmp::specialized::<0>(chunk, needle)),
227-
Some(2) => dispatch!(memcmp::specialized::<1>(chunk, needle)),
228-
Some(3) => dispatch!(memcmp::specialized::<2>(chunk, needle)),
229-
Some(4) => dispatch!(memcmp::specialized::<3>(chunk, needle)),
230-
Some(5) => dispatch!(memcmp::specialized::<4>(chunk, needle)),
231-
Some(6) => dispatch!(memcmp::specialized::<5>(chunk, needle)),
232-
Some(7) => dispatch!(memcmp::specialized::<6>(chunk, needle)),
233-
Some(8) => dispatch!(memcmp::specialized::<7>(chunk, needle)),
234-
Some(9) => dispatch!(memcmp::specialized::<8>(chunk, needle)),
235-
Some(10) => dispatch!(memcmp::specialized::<9>(chunk, needle)),
236-
Some(11) => dispatch!(memcmp::specialized::<10>(chunk, needle)),
237-
Some(12) => dispatch!(memcmp::specialized::<11>(chunk, needle)),
238-
Some(13) => dispatch!(memcmp::specialized::<12>(chunk, needle)),
239-
Some(14) => dispatch!(memcmp::specialized::<13>(chunk, needle)),
240-
Some(15) => dispatch!(memcmp::specialized::<14>(chunk, needle)),
241-
Some(16) => dispatch!(memcmp::specialized::<15>(chunk, needle)),
242-
_ => dispatch!(memcmp::generic(chunk, needle, self.needle().size() - 1)),
243-
};
244-
if equal {
245-
return true;
246-
}
243+
false
244+
}
247245

248-
eq = dispatch!(bits::clear_leftmost_set(eq));
246+
#[allow(dead_code)]
247+
#[multiversion::multiversion]
248+
#[clone(target = "[x86|x86_64]+avx2")]
249+
#[clone(target = "wasm32+simd128")]
250+
#[clone(target = "aarch64+neon")]
251+
pub(crate) unsafe fn vector_search_in<N: NeedleWithSize + ?Sized, V: Vector>(
252+
needle: &N,
253+
position: usize,
254+
haystack: &[u8],
255+
end: usize,
256+
hash: &VectorHash<V>,
257+
) -> bool {
258+
debug_assert!(haystack.len() >= needle.size());
259+
260+
let mut chunks = haystack[..end].chunks_exact(V::LANES);
261+
for chunk in &mut chunks {
262+
if dispatch!(vector_search_in_chunk(
263+
needle,
264+
position,
265+
hash,
266+
chunk.as_ptr(),
267+
u32::MAX
268+
)) {
269+
return true;
249270
}
271+
}
250272

251-
false
252-
}
253-
254-
#[multiversion::multiversion]
255-
#[clone(target = "[x86|x86_64]+avx2")]
256-
#[clone(target = "wasm32+simd128")]
257-
#[clone(target = "aarch64+neon")]
258-
unsafe fn vector_search_in<V: Vector>(
259-
&self,
260-
haystack: &[u8],
261-
end: usize,
262-
hash: &VectorHash<V>,
263-
) -> bool {
264-
debug_assert!(haystack.len() >= self.needle().size());
265-
266-
let mut chunks = haystack[..end].chunks_exact(V::LANES);
267-
for chunk in &mut chunks {
268-
if dispatch!(self.vector_search_in_chunk(hash, chunk.as_ptr(), u32::MAX)) {
269-
return true;
270-
}
273+
let remainder = chunks.remainder().len();
274+
if remainder > 0 {
275+
let start = haystack.as_ptr().add(end - V::LANES);
276+
let mask = u32::MAX << (V::LANES - remainder);
277+
278+
if dispatch!(vector_search_in_chunk(needle, position, hash, start, mask)) {
279+
return true;
271280
}
281+
}
272282

273-
let remainder = chunks.remainder().len();
274-
if remainder > 0 {
275-
let start = haystack.as_ptr().add(end - V::LANES);
276-
let mask = u32::MAX << (V::LANES - remainder);
283+
false
284+
}
277285

278-
if dispatch!(self.vector_search_in_chunk(hash, start, mask)) {
279-
return true;
280-
}
281-
}
286+
trait Searcher<N: NeedleWithSize + ?Sized> {
287+
fn needle(&self) -> &N;
282288

283-
false
284-
}
289+
fn position(&self) -> usize;
285290
}
286291

287292
#[cfg(test)]

src/stdsimd.rs

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,28 +129,47 @@ impl<N: Needle> StdSimdSearcher<N> {
129129
/// Inlined version of `search_in` for hot call sites.
130130
#[inline]
131131
pub fn inlined_search_in(&self, haystack: &[u8]) -> bool {
132-
if haystack.len() <= self.needle.size() {
133-
return haystack == self.needle.as_bytes();
132+
let needle = self.needle();
133+
134+
if haystack.len() <= needle.size() {
135+
return haystack == needle.as_bytes();
134136
}
135137

136-
let end = haystack.len() - self.needle.size() + 1;
138+
let position = self.position();
139+
let end = haystack.len() - needle.size() + 1;
137140

138141
if end < Simd2::LANES {
139142
unreachable!();
140143
} else if end < Simd4::LANES {
141144
let hash = from_hash::<32, 2>(&self.simd32_hash);
142-
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
145+
unsafe {
146+
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
147+
}
143148
} else if end < Simd8::LANES {
144149
let hash = from_hash::<32, 4>(&self.simd32_hash);
145-
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
150+
unsafe {
151+
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
152+
}
146153
} else if end < Simd16::LANES {
147154
let hash = from_hash::<32, 8>(&self.simd32_hash);
148-
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
155+
unsafe {
156+
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
157+
}
149158
} else if end < Simd32::LANES {
150159
let hash = from_hash::<32, 16>(&self.simd32_hash);
151-
unsafe { self.vector_search_in_default_version(haystack, end, &hash) }
160+
unsafe {
161+
crate::vector_search_in_default_version(needle, position, haystack, end, &hash)
162+
}
152163
} else {
153-
unsafe { self.vector_search_in_default_version(haystack, end, &self.simd32_hash) }
164+
unsafe {
165+
crate::vector_search_in_default_version(
166+
needle,
167+
position,
168+
haystack,
169+
end,
170+
&self.simd32_hash,
171+
)
172+
}
154173
}
155174
}
156175

src/wasm32.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,25 +242,29 @@ impl<N: Needle> Wasm32Searcher<N> {
242242
#[inline]
243243
#[target_feature(enable = "simd128")]
244244
pub unsafe fn inlined_search_in(&self, haystack: &[u8]) -> bool {
245-
if haystack.len() <= self.needle.size() {
246-
return haystack == self.needle.as_bytes();
245+
let needle = self.needle();
246+
247+
if haystack.len() <= needle.size() {
248+
return haystack == needle.as_bytes();
247249
}
248250

249-
let end = haystack.len() - self.needle.size() + 1;
251+
let position = self.position();
252+
let end = haystack.len() - needle.size() + 1;
250253

251254
if end < v16::LANES {
252255
unreachable!();
253256
} else if end < v32::LANES {
254257
let hash = VectorHash::<v16>::from(&self.v128_hash);
255-
self.vector_search_in_simd128_version(haystack, end, &hash)
258+
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
256259
} else if end < v64::LANES {
257260
let hash = VectorHash::<v32>::from(&self.v128_hash);
258-
self.vector_search_in_simd128_version(haystack, end, &hash)
261+
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
259262
} else if end < v128::LANES {
260263
let hash = VectorHash::<v64>::from(&self.v128_hash);
261-
self.vector_search_in_simd128_version(haystack, end, &hash)
264+
crate::vector_search_in_simd128_version(needle, position, haystack, end, &hash)
262265
} else {
263-
self.vector_search_in_simd128_version(haystack, end, &self.v128_hash)
266+
let hash = &self.v128_hash;
267+
crate::vector_search_in_simd128_version(needle, position, haystack, end, hash)
264268
}
265269
}
266270

src/x86.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,33 +319,35 @@ impl<N: Needle> Avx2Searcher<N> {
319319
#[target_feature(enable = "avx2")]
320320
unsafe fn sse2_2_search_in(&self, haystack: &[u8], end: usize) -> bool {
321321
let hash = VectorHash::<__m16i>::from(&self.sse2_hash);
322-
self.vector_search_in_avx2_version(haystack, end, &hash)
322+
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
323323
}
324324

325325
#[inline]
326326
#[target_feature(enable = "avx2")]
327327
unsafe fn sse2_4_search_in(&self, haystack: &[u8], end: usize) -> bool {
328328
let hash = VectorHash::<__m32i>::from(&self.sse2_hash);
329-
self.vector_search_in_avx2_version(haystack, end, &hash)
329+
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
330330
}
331331

332332
#[inline]
333333
#[target_feature(enable = "avx2")]
334334
unsafe fn sse2_8_search_in(&self, haystack: &[u8], end: usize) -> bool {
335335
let hash = VectorHash::<__m64i>::from(&self.sse2_hash);
336-
self.vector_search_in_avx2_version(haystack, end, &hash)
336+
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, &hash)
337337
}
338338

339339
#[inline]
340340
#[target_feature(enable = "avx2")]
341341
unsafe fn sse2_16_search_in(&self, haystack: &[u8], end: usize) -> bool {
342-
self.vector_search_in_avx2_version(haystack, end, &self.sse2_hash)
342+
let hash = &self.sse2_hash;
343+
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, hash)
343344
}
344345

345346
#[inline]
346347
#[target_feature(enable = "avx2")]
347348
unsafe fn avx2_search_in(&self, haystack: &[u8], end: usize) -> bool {
348-
self.vector_search_in_avx2_version(haystack, end, &self.avx2_hash)
349+
let hash = &self.avx2_hash;
350+
crate::vector_search_in_avx2_version(self.needle(), self.position(), haystack, end, hash)
349351
}
350352

351353
/// Inlined version of `search_in` for hot call sites.

0 commit comments

Comments
 (0)