Skip to content

Commit 54e6334

Browse files
authored
Merge pull request #72 from thecppzoo/em/swar-demos
SWAR Demos - stage of converting eight ASCII bytes to int and string length, including AVX2 implementation.
2 parents e5ab9dc + bdde4a4 commit 54e6334

File tree

8 files changed

+275
-39
lines changed

8 files changed

+275
-39
lines changed

benchmark/atoi-corpus.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "atoi.h"
2+
#include "zoo/pp/platform.h"
23

34
#include <vector>
45
#include <string>
@@ -119,17 +120,33 @@ struct CorpusStringLength {
119120
}
120121
};
121122

123+
#if ZOO_CONFIGURED_TO_USE_AVX()
124+
#define AVX2_STRLEN_CORPUS_X_LIST \
125+
X(ZOO_AVX, zoo::avx2_strlen)
126+
#else
127+
#define AVX2_STRLEN_CORPUS_X_LIST /* nothing */
128+
#endif
129+
130+
#if ZOO_CONFIGURED_TO_USE_NEON()
131+
#define NEON_STRLEN_CORPUS_X_LIST \
132+
X(ZOO_NEON, zoo::neon_strlen)
133+
#else
134+
#define NEON_STRLEN_CORPUS_X_LIST /* nothing */
135+
#endif
136+
137+
122138
#define STRLEN_CORPUS_X_LIST \
123139
X(LIBC_STRLEN, strlen) \
124140
X(ZOO_STRLEN, zoo::c_strLength) \
125141
X(ZOO_NATURAL_STRLEN, zoo::c_strLength_natural) \
126-
X(ZOO_MANUAL_STRLEN, zoo::c_strLength_manualComparison) \
127-
X(ZOO_AVX, zoo::avx2_strlen) \
128-
X(GENERIC_GLIBC_STRLEN, STRLEN_old)
142+
X(GENERIC_GLIBC_STRLEN, STRLEN_old) \
143+
AVX2_STRLEN_CORPUS_X_LIST \
144+
NEON_STRLEN_CORPUS_X_LIST
129145

130146
#define X(Typename, FunctionToCall) \
131147
struct Invoke##Typename { int operator()(const char *p) { return FunctionToCall(p); } };
132148

133149
PARSE8BYTES_CORPUS_X_LIST
134150
STRLEN_CORPUS_X_LIST
151+
135152
#undef X

benchmark/atoi.cpp

Lines changed: 183 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1-
#include "zoo/swar/SWAR.h"
1+
#include "atoi.h"
2+
23
#include "zoo/swar/associative_iteration.h"
34

5+
#if ZOO_CONFIGURED_TO_USE_AVX()
46
#include <immintrin.h>
7+
#endif
58

69
#include <stdint.h>
710
#include <string.h>
811
#include <stdlib.h>
12+
#include <ctype.h>
13+
14+
#include <tuple>
915

1016
// Copied from Daniel Lemire's GitHub at
1117
// https://lemire.me/blog/2018/10/03/quickly-parsing-eight-digits/
@@ -23,7 +29,7 @@ uint32_t parse_eight_digits_swar(const char *chars) {
2329

2430
// Note: eight digits can represent from 0 to (10^9) - 1, the logarithm base 2
2531
// of 10^9 is slightly less than 30, thus, only 30 bits are needed.
26-
auto lemire_as_zoo_swar(const char *chars) {
32+
uint32_t lemire_as_zoo_swar(const char *chars) {
2733
uint64_t bytes;
2834
memcpy(&bytes, chars, 8);
2935
auto allCharacterZero = zoo::meta::BitmaskMaker<uint64_t, '0', 8>::value;
@@ -52,16 +58,96 @@ auto lemire_as_zoo_swar(const char *chars) {
5258
return uint32_t(by10001base2to32.value() >> 32);
5359
}
5460

61+
/// \brief Counts the whitespace characters at the beginning of a C string
/// using the C library's isspace.
/// \param ptr null-terminated string to inspect
/// \return the number of consecutive characters, starting at \c ptr, for
/// which isspace is true
std::size_t spaces_glibc(const char *ptr) {
    // The counter has the type of the return value, to avoid an int -> size_t
    // conversion at the return.
    std::size_t rv = 0;
    // isspace requires an argument representable as unsigned char (or EOF):
    // passing a plain char with the high bit set is undefined behavior where
    // char is signed, hence the cast.
    while(isspace(static_cast<unsigned char>(ptr[rv]))) { ++rv; }
    return rv;
}
66+
5567
namespace zoo {
5668

69+
//constexpr
70+
std::size_t leadingSpacesCount(swar::SWAR<8, uint64_t> bytes) noexcept {
71+
/*
72+
space (0x20, ' ')
73+
form feed (0x0c, '\f')
74+
line feed (0x0a, '\n')
75+
carriage return (0x0d, '\r')
76+
horizontal tab (0x09, '\t')
77+
vertical tab (0x0b, '\v')*
78+
constexpr std::array<char, 6> SpaceCharacters = {
79+
0b10'0000, //0x20 space
80+
0b00'1101, // 0xD \r
81+
0b00'1100, // 0xC \f
82+
0b00'1011, // 0xB \v
83+
0b00'1010, // 0xA \n
84+
0b00'1001 // 9 \t
85+
},
86+
ExpressedAsEscapeCodes = { ' ', '\r', '\f', '\v', '\n', '\t' };
87+
static_assert(SpaceCharacters == ExpressedAsEscapeCodes); */
88+
using S = swar::SWAR<8, uint64_t>;
89+
constexpr S Space{meta::BitmaskMaker<uint64_t, ' ', 8>::value};
90+
auto space = swar::equals(bytes, Space);
91+
auto otherWhiteSpace =
92+
swar::constantIsGreaterEqual<'\r'>(bytes) &
93+
~swar::constantIsGreaterEqual<'\t' - 1>(bytes);
94+
auto whiteSpace = space | otherWhiteSpace;
95+
auto notWhiteSpace = S{S::MostSignificantBit} ^ whiteSpace;
96+
return notWhiteSpace.lsbIndex();
97+
}
98+
99+
/// @brief Loads the "block" containing the pointer, by proper alignment
/// @tparam PtrT Pointee type of the source pointer (for example `const char`)
/// @tparam Block Type of the block to load; its alignment must be equal to
/// its size, this is checked at compilation time
/// @param pointerInsideBlock the potentially misaligned pointer
/// @param b where the loaded bytes will be put
/// @return a tuple with the aligned pointer to the base of the block and the
/// misalignment, in bytes, of the source pointer
/// @note The whole block is read, including the bytes between the base of
/// the block and \c pointerInsideBlock; the caller is responsible for
/// disregarding those leading bytes.
template<typename PtrT, typename Block>
std::tuple<PtrT *, int>
blockAlignedLoad(PtrT *pointerInsideBlock, Block *b) {
    constexpr auto Alignment = alignof(Block), Size = sizeof(Block);
    static_assert(Alignment == Size);
    const auto asUint = reinterpret_cast<uintptr_t>(pointerInsideBlock);
    // Distance in bytes from the base of the block to the given pointer.
    const auto misalignment = asUint % Alignment;
    auto *base = reinterpret_cast<PtrT *>(asUint - misalignment);
    memcpy(b, base, Size);
    // The remainder is unsigned and smaller than Alignment: the conversion
    // to the int of the return type is made explicit.
    return { base, static_cast<int>(misalignment) };
}
117+
118+
/// \brief Neutralizes, for the purposes of strlen, the bytes of a block that
/// precede the string.
///
/// A block loaded from an aligned address has the bytes of the string in its
/// higher lanes; the \c misalignment lower lanes are not part of the string.
/// They are filled with ones so that a search for nulls never matches them.
/// \tparam S SWAR-like type: exposes the underlying integer as \c S::type,
/// is constructible from it and supports operator|
/// \param data the block as loaded
/// \param misalignment how many of the lowest bytes do not belong to the string
/// \return \c data with every bit of the \c misalignment lowest bytes set
template<typename S>
S adjustMisalignmentFor_strlen(S data, int misalignment) {
    using Bits = typename S::type;
    constexpr Bits NoBits{0};
    const auto bitsBeforeString = misalignment * 8; // assumes 8 bits per char
    // Ones in the lanes that hold the string, zeroes in the lanes before it.
    const auto validLanes = (~NoBits) << bitsBeforeString;
    // The complement: ones exactly in the lanes to be neutralized.
    const auto fillerForLeadingLanes = ~validLanes;
    return data | S{fillerForLeadingLanes};
}
132+
57133
std::size_t c_strLength(const char *s) {
58-
using S = swar::SWAR<8, std::size_t>;
134+
using S = swar::SWAR<8, uint64_t>;
59135
constexpr auto
60136
MSBs = S{S::MostSignificantBit},
61137
Ones = S{S::LeastSignificantBit};
62-
S bytes;
63-
for(auto base = s;; base += 8) {
64-
memcpy(&bytes.m_v, base, 8);
138+
constexpr auto BytesPerIteration = sizeof(S::type);
139+
S initialBytes;
140+
141+
auto indexOfFirstTrue = [](auto bs) { return bs.lsbIndex(); };
142+
143+
// Misalignment must be taken into account because a SWAR read is
144+
// speculative, it might read bytes outside of the actual string.
145+
// It is safe to read within the page where the string occurs, and to
146+
// guarantee that, simply make aligned reads because the size of the SWAR
147+
// base type will always divide the memory page size
148+
auto [alignedBase, misalignment] = blockAlignedLoad(s, &initialBytes);
149+
auto bytes = adjustMisalignmentFor_strlen(initialBytes, misalignment);
150+
for(;;) {
65151
auto firstNullTurnsOnMSB = bytes - Ones;
66152
// The first lane with a null will borrow and set its MSB on when
67153
// subtracted one.
@@ -74,24 +160,28 @@ std::size_t c_strLength(const char *s) {
74160
auto cheapestInversionOfMSBs = ~bytes;
75161
auto firstMSBsOnIsFirstNull =
76162
firstNullTurnsOnMSB & cheapestInversionOfMSBs;
77-
auto onlyMSBs = zoo::swar::convertToBooleanSWAR(firstMSBsOnIsFirstNull);
78-
if(onlyMSBs) { // there is a null!
79-
auto firstNullIndex = onlyMSBs.lsbIndex();
80-
return firstNullIndex + (base - s);
163+
auto onlyMSBs = swar::convertToBooleanSWAR(firstMSBsOnIsFirstNull);
164+
if(onlyMSBs) {
165+
return alignedBase + indexOfFirstTrue(onlyMSBs) - s;
81166
}
167+
alignedBase += BytesPerIteration;
168+
memcpy(&bytes, alignedBase, BytesPerIteration);
82169
}
83170
}
84171

85172
std::size_t c_strLength_natural(const char *s) {
86-
using S = swar::SWAR<8, std::size_t>;
87-
S bytes;
88-
for(auto base = s;; base += 8) {
89-
memcpy(&bytes.m_v, base, 8);
173+
using S = swar::SWAR<8, std::uint64_t>;
174+
S initialBytes;
175+
auto [base, misalignment] = blockAlignedLoad(s, &initialBytes);
176+
auto bytes = adjustMisalignmentFor_strlen(initialBytes, misalignment);
177+
for(;;) {
90178
auto nulls = zoo::swar::equals(bytes, S{0});
91179
if(nulls) { // there is a null!
92180
auto firstNullIndex = nulls.lsbIndex();
93-
return firstNullIndex + (base - s);
181+
return firstNullIndex + base - s;
94182
}
183+
base += sizeof(S);
184+
memcpy(&bytes.m_v, base, 8);
95185
}
96186
}
97187

@@ -117,29 +207,47 @@ std::size_t c_strLength_manualComparison(const char *s) {
117207
}
118208
}
119209

210+
#if ZOO_CONFIGURED_TO_USE_AVX()
211+
212+
/// \note Partially generated by Chat GPT 4
120213
size_t avx2_strlen(const char* str) {
121214
const __m256i zero = _mm256_setzero_si256(); // Vector of 32 zero bytes
122215
size_t offset = 0;
216+
__m256i data;
217+
auto [alignedBase, misalignment] = blockAlignedLoad(str, &data);
123218

124-
// Loop over the string in blocks of 32 bytes
125-
for (;; offset += 32) {
126-
// Load 32 bytes of the string into a __m256i vector
127-
__m256i data;// = _mm256_load_si256((const __m256i*)(str + offset));
128-
memcpy(&data, str + offset, 32);
129-
// Compare each byte with '\0'
130-
__m256i cmp = _mm256_cmpeq_epi8(data, zero);
131-
// Create a mask indicating which bytes are '\0'
132-
int mask = _mm256_movemask_epi8(cmp);
219+
// AVX does not offer a practical way to generate a mask of all ones in
220+
// the least significant positions, thus we can't invoke adjustMisalignmentFor_strlen.
221+
// We will do the first iteration as a special case to explicitly take into
222+
// account misalignment
223+
auto maskOfMask = (~uint64_t(0)) << misalignment;
133224

225+
auto compareAndMask =
226+
[&]() {
227+
// Compare each byte with '\0'
228+
__m256i cmp = _mm256_cmpeq_epi8(data, zero);
229+
// Create a mask indicating which bytes are '\0'
230+
return _mm256_movemask_epi8(cmp);
231+
};
232+
auto mask = compareAndMask();
233+
mask &= maskOfMask;
234+
235+
// Loop over the string in blocks of 32 bytes
236+
for (;;) {
134237
// If mask is not zero, we found a '\0' byte
135238
if (mask) {
136-
// Calculate the index of the first '\0' byte using ctz (Count Trailing Zeros)
137-
return offset + __builtin_ctz(mask);
239+
// Calculate the index of the first '\0' byte counting trailing 0s
240+
auto nunNullByteCount = __builtin_ctz(mask);
241+
return alignedBase + offset + nunNullByteCount - str;
138242
}
243+
offset += 32;
244+
memcpy(&data, alignedBase + offset, 32);
245+
mask = compareAndMask();
139246
}
140247
// Unreachable, but included to avoid compiler warnings
141248
return offset;
142249
}
250+
#endif
143251

144252
}
145253

@@ -217,3 +325,52 @@ STRLEN_old (const char *str)
217325
}
218326
}
219327
}
328+
329+
330+
#if ZOO_CONFIGURED_TO_USE_NEON()
331+
332+
#include <arm_neon.h>
333+
334+
namespace zoo {
335+
336+
/// \note uses the key technique of shifting by 4 and narrowing from 16 to 8 bit lanes in
337+
/// aarch64/strlen.S at
338+
/// https://sourceware.org/git/?p=glibc.git;a=blob;f=sysdeps/aarch64/strlen.S;h=ab2a576cdb5665e596b791299af3f4abecb73c0e;hb=HEAD
339+
std::size_t neon_strlen(const char *str) {
340+
const uint8x16_t zero = vdupq_n_u8(0);
341+
size_t offset = 0;
342+
uint8x16_t data;
343+
auto [alignedBase, misalignment] = blockAlignedLoad(str, &data);
344+
345+
auto compareAndConvertResultsToNibbles = [&]() {
346+
auto cmp = vceqq_u8(data, zero);
347+
// The result looks like, in hexadecimal digits, like this:
348+
// [ AA, BB, CC, DD, EE, FF, GG, HH, ... ] with each
349+
// variable A, B, ... either 0xF or 0x0.
350+
// instead of 16x8 bit results, we can see that as
351+
// 8 16 bit results like this
352+
// [ AABB, CCDD, EEFF, GGHH, ... ]
353+
// If we shift out a nibble from each element (shift right by 4):
354+
// [ ABB0, CDD0, EFF0, GHH0, ... ]
355+
// Narrowing from 16 to eight, we would get
356+
// [ AB, CD, EF, GH, ... ]
357+
auto straddle8bitLanePairAndNarrowToBytes = vshrn_n_u16(cmp, 4);
358+
return vget_lane_u64(vreinterpret_u64_u8(straddle8bitLanePairAndNarrowToBytes), 0);
359+
};
360+
auto nibbles = compareAndConvertResultsToNibbles();
361+
auto misalignmentNibbleMask = (~uint64_t(0)) << (misalignment * 4);
362+
nibbles &= misalignmentNibbleMask;
363+
for(;;) {
364+
if(nibbles) {
365+
auto trailingZeroBits = __builtin_ctz(nibbles);
366+
auto nonNullByteCount = trailingZeroBits / 4;
367+
return alignedBase + offset + nonNullByteCount - str;
368+
}
369+
alignedBase += sizeof(uint8x16_t);
370+
memcpy(&data, alignedBase, sizeof(uint8x16_t));
371+
nibbles = compareAndConvertResultsToNibbles();
372+
}
373+
}
374+
375+
}
376+
#endif

benchmark/atoi.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1-
#include "stdint.h"
1+
#include "zoo/swar/SWAR.h"
2+
#include "zoo/pp/platform.h"
3+
24
#include <cstdlib>
35

46
uint32_t parse_eight_digits_swar(const char *chars);
57
uint32_t lemire_as_zoo_swar(const char *chars);
68

9+
std::size_t spaces_glibc(const char *ptr);
10+
711
namespace zoo {
812

13+
std::size_t leadingSpacesCount(swar::SWAR<8, uint64_t> bytes) noexcept;
914
std::size_t c_strLength(const char *s);
1015
std::size_t c_strLength_natural(const char *s);
11-
std::size_t c_strLength_manualComparison(const char *s);
16+
17+
#if ZOO_CONFIGURED_TO_USE_AVX()
1218
std::size_t avx2_strlen(const char* str);
19+
#endif
20+
21+
#if ZOO_CONFIGURED_TO_USE_NEON()
22+
std::size_t neon_strlen(const char* str);
23+
#endif
1324

1425
}
1526

benchmark/catch2swar-demo.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,22 @@ TEST_CASE("Atoi benchmarks", "[atoi][swar]") {
1717
auto seed = rd();
1818
CAPTURE(seed);
1919
std::mt19937 g(seed);
20+
SECTION("Simple comparison of two strings") {
21+
auto TwoStrings = "Str1\0Much longer string here, even for AVX2";
22+
auto zoolength1 = zoo::c_strLength(TwoStrings);
23+
auto strlen1 = strlen(TwoStrings);
24+
REQUIRE(zoolength1 == strlen1);
25+
auto skipFst = TwoStrings + strlen1 + 1;
26+
auto zl2 = zoo::c_strLength(skipFst);
27+
auto strlen2 = strlen(skipFst);
28+
REQUIRE(zl2 == strlen2);
29+
#if ZOO_CONFIGURED_TO_USE_AVX()
30+
auto avx1 = zoo::avx2_strlen(TwoStrings);
31+
REQUIRE(avx1 == strlen1);
32+
auto avx2 = zoo::avx2_strlen(skipFst);
33+
REQUIRE(avx2 == strlen2);
34+
#endif
35+
}
2036
auto corpus8D = Corpus8DecimalDigits::makeCorpus(g);
2137
auto corpusStrlen = CorpusStringLength::makeCorpus(g);
2238
#define X(Type, Fun) \
@@ -34,9 +50,10 @@ TEST_CASE("Atoi benchmarks", "[atoi][swar]") {
3450
REQUIRE(fromLIBC == fromZoo);
3551
REQUIRE(fromZOO_STRLEN == fromLIBC_STRLEN);
3652
REQUIRE(fromLIBC_STRLEN == fromZOO_NATURAL_STRLEN);
37-
REQUIRE(fromZOO_NATURAL_STRLEN == fromZOO_MANUAL_STRLEN);
3853
REQUIRE(fromGENERIC_GLIBC_STRLEN == fromZOO_NATURAL_STRLEN);
39-
REQUIRE(fromZOO_AVX == fromZOO_STRLEN);
54+
#if ZOO_CONFIGURED_TO_USE_AVX()
55+
REQUIRE(fromZOO_AVX == fromZOO_STRLEN);
56+
#endif
4057

4158
auto haveTheRoleOfMemoryBarrier = -1;
4259
#define X(Type, Fun) \

0 commit comments

Comments
 (0)