Skip to content

Commit 20cf880

Browse files
committed
Parallelize the compression step in serialization.
We split the LZ4 compression step across up to 4 threads. This can actually be faster than not compressing at all, possibly because the smaller data reduces pressure on the bus.
1 parent a2cd9aa commit 20cf880

8 files changed

+231
-34
lines changed

typed_python/NullSerializationContext.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class NullSerializationContext : public SerializationContext {
3737
}
3838

3939

40+
// The null context always reports that compression should not use worker threads.
virtual bool compressUsingThreads() const {
41+
return false;
42+
}
4043
virtual bool serializePodListsInline() const {
4144
return false;
4245
}

typed_python/PythonSerializationContext.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ void PythonSerializationContext::setFlags() {
4646

4747
mSerializePodListsInline = ((PyObject*)serializePodListsInline) == Py_True;
4848

49+
PyObjectStealer compressUsingThreads(PyObject_GetAttrString(mContextObj, "compressUsingThreads"));
50+
51+
if (!compressUsingThreads) {
52+
throw PythonExceptionSet();
53+
}
54+
55+
mCompressUsingThreads = ((PyObject*)compressUsingThreads) == Py_True;
56+
4957
PyObjectStealer encodeLineInformationForCode(PyObject_GetAttrString(mContextObj, "encodeLineInformationForCode"));
5058

5159
if (!encodeLineInformationForCode) {

typed_python/PythonSerializationContext.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ class PythonSerializationContext : public SerializationContext {
8484
mContextObj(typeSetObj),
8585
mCompressionEnabled(false),
8686
mSerializePodListsInline(false),
87-
mSerializeHashSequence(false)
87+
mSerializeHashSequence(false),
88+
mCompressUsingThreads(false)
8889
{
8990
setFlags();
9091
}
@@ -99,6 +100,10 @@ class PythonSerializationContext : public SerializationContext {
99100
return mSuppressLineInfo;
100101
}
101102

103+
// Returns the cached flag that setFlags() reads from the context object's
// `compressUsingThreads` attribute (true only when that attribute is Py_True).
bool compressUsingThreads() const {
104+
return mCompressUsingThreads;
105+
}
106+
102107
bool serializePodListsInline() const {
103108
return mSerializePodListsInline;
104109
}
@@ -196,6 +201,8 @@ class PythonSerializationContext : public SerializationContext {
196201

197202
bool mSerializePodListsInline;
198203

204+
bool mCompressUsingThreads;
205+
199206
bool mSuppressLineInfo;
200207

201208
bool mSerializeHashSequence;

typed_python/SerializationBuffer.cpp

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ void SerializationBufferBlock::compress() {
4444
size_t compressedBytecount;
4545

4646
{
47-
PyEnsureGilReleased releaseTheGil;
48-
4947
compressedBytecount = LZ4F_compressFrame(
5048
compressedBytes,
5149
bytesRequired,
@@ -76,12 +74,18 @@ void SerializationBufferBlock::compress() {
7674
}
7775

7876
void SerializationBuffer::consolidate() {
77+
if (m_is_consolidated) {
78+
return;
79+
}
80+
7981
if (m_wants_compress) {
8082
for (auto blockPtr: m_blocks) {
81-
blockPtr->compress();
83+
waitForCompression(blockPtr);
8284
}
8385
}
8486

87+
m_is_consolidated = true;
88+
8589
if (m_blocks.size() == 1) {
8690
return;
8791
}
@@ -110,3 +114,99 @@ void SerializationBuffer::consolidate() {
110114
)
111115
);
112116
}
117+
118+
std::shared_ptr<SerializationBufferBlock> SerializationBuffer::getNextCompressTask() {
119+
std::unique_lock<std::mutex> lock(s_compress_thread_mutex);
120+
121+
while (true) {
122+
if (s_waiting_compress_blocks.size()) {
123+
std::shared_ptr<SerializationBufferBlock> res =
124+
*s_waiting_compress_blocks.begin();
125+
126+
s_working_compress_blocks.insert(res);
127+
s_waiting_compress_blocks.erase(res);
128+
129+
return res;
130+
}
131+
132+
s_has_work->wait(lock);
133+
}
134+
}
135+
136+
void SerializationBuffer::compressionThread() {
137+
while (true) {
138+
std::shared_ptr<SerializationBufferBlock> task = getNextCompressTask();
139+
140+
task->compress();
141+
142+
std::unique_lock<std::mutex> lock(s_compress_thread_mutex);
143+
144+
s_working_compress_blocks.erase(task);
145+
146+
s_has_work->notify_all();
147+
}
148+
}
149+
150+
void SerializationBuffer::waitForCompression(std::shared_ptr<SerializationBufferBlock> block) {
151+
if (!m_compress_using_threads) {
152+
PyEnsureGilReleased releaseTheGil;
153+
block->compress();
154+
return;
155+
}
156+
157+
{
158+
std::unique_lock<std::mutex> lock(s_compress_thread_mutex);
159+
160+
while (true) {
161+
if (
162+
s_waiting_compress_blocks.find(block) == s_waiting_compress_blocks.end()
163+
&& s_working_compress_blocks.find(block) == s_working_compress_blocks.end()
164+
) {
165+
// we're done
166+
return;
167+
}
168+
169+
s_has_work->wait(lock);
170+
}
171+
}
172+
}
173+
174+
void SerializationBuffer::markForCompression(std::shared_ptr<SerializationBufferBlock> block) {
175+
if (!m_compress_using_threads) {
176+
return;
177+
}
178+
179+
{
180+
std::unique_lock<std::mutex> lock(s_compress_thread_mutex);
181+
182+
if (!s_compress_threads.size()) {
183+
for (long i = 0; i < 4; i++) {
184+
s_compress_threads.push_back(
185+
new std::thread(SerializationBuffer::compressionThread)
186+
);
187+
}
188+
189+
s_has_work = new std::condition_variable();
190+
}
191+
192+
s_waiting_compress_blocks.insert(block);
193+
s_has_work->notify_all();
194+
}
195+
}
196+
197+
// static: locked by every function in this file that reads or modifies the
// compression queue state defined below
198+
std::mutex SerializationBuffer::s_compress_thread_mutex;
199+
200+
// static: allocated lazily by markForCompression(); notified when a block is
// queued and when a worker finishes compressing one
201+
std::condition_variable* SerializationBuffer::s_has_work;
202+
203+
// static: worker threads running compressionThread(), created on first use
// by markForCompression()
204+
std::vector<std::thread*> SerializationBuffer::s_compress_threads;
205+
206+
// static: blocks queued for compression that no worker has claimed yet
207+
std::unordered_set<std::shared_ptr<SerializationBufferBlock> > SerializationBuffer::s_waiting_compress_blocks;
208+
209+
// static: blocks a worker has claimed and is currently compressing
210+
std::unordered_set<std::shared_ptr<SerializationBufferBlock> > SerializationBuffer::s_working_compress_blocks;
211+
212+

typed_python/SerializationBuffer.hpp

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <stdlib.h>
2121
#include <map>
2222
#include <set>
23+
#include <thread>
24+
#include <condition_variable>
2325
#include "Type.hpp"
2426
#include "WireType.hpp"
2527

@@ -115,18 +117,18 @@ class SerializationBufferBlock {
115117
size_t m_size;
116118
size_t m_reserved;
117119
uint8_t* m_buffer;
118-
119120
bool m_compressed;
120121
};
121122

122123
class SerializationBuffer {
123124
public:
124125
SerializationBuffer(const SerializationContext& context) :
125126
m_context(context),
126-
m_wants_compress(context.isCompressionEnabled())
127+
m_wants_compress(context.isCompressionEnabled()),
128+
m_compress_using_threads(context.compressUsingThreads()),
129+
m_is_consolidated(false)
127130
{
128131
m_top_block = new SerializationBufferBlock();
129-
130132
m_blocks.push_back(std::shared_ptr<SerializationBufferBlock>(m_top_block));
131133
}
132134

@@ -265,7 +267,9 @@ class SerializationBuffer {
265267
}
266268

267269
uint8_t* buffer() {
268-
consolidate();
270+
if (!m_is_consolidated) {
271+
consolidate();
272+
}
269273

270274
if (m_blocks.size() == 0) {
271275
return nullptr;
@@ -293,30 +297,23 @@ class SerializationBuffer {
293297

294298
//nakedly write bytes into the stream
295299
void write_bytes(uint8_t* ptr, size_t bytecount) {
296-
if (m_top_block->isCompressed() || m_top_block->oversized()) {
297-
m_top_block = new SerializationBufferBlock();
298-
m_blocks.push_back(
299-
std::shared_ptr<SerializationBufferBlock>(
300-
m_top_block
301-
)
302-
);
300+
while (bytecount > 1024 * 1024) {
301+
write_bytes(ptr, 1024 * 1024);
302+
303+
ptr += 1024 * 1024;
304+
bytecount -= 1024 * 1024;
303305
}
304306

307+
checkTopBlock();
308+
305309
m_top_block->write_bytes(ptr, bytecount);
306310
}
307311

308312
// allocate some memory and call 'c' with a uint8_t* pointing at it
309313
// to initialize it.
310314
template<class callback>
311315
void initialize_bytes(size_t bytecount, const callback& c) {
312-
if (m_top_block->isCompressed() || m_top_block->oversized()) {
313-
m_top_block = new SerializationBufferBlock();
314-
m_blocks.push_back(
315-
std::shared_ptr<SerializationBufferBlock>(
316-
m_top_block
317-
)
318-
);
319-
}
316+
checkTopBlock();
320317

321318
m_top_block->initialize_bytes(bytecount, c);
322319
}
@@ -405,24 +402,35 @@ class SerializationBuffer {
405402

406403
void finalize() {
407404
if (m_wants_compress) {
405+
markForCompression(m_blocks.back());
406+
408407
for (auto b: m_blocks) {
409-
b->compress();
408+
waitForCompression(b);
410409
}
411410
}
412411
}
413412

414413
template< class T>
415414
void write(T i) {
416-
if (m_top_block->isCompressed() || m_top_block->oversized()) {
415+
checkTopBlock();
416+
417+
m_top_block->write(i);
418+
}
419+
420+
void checkTopBlock() {
421+
if (m_top_block->oversized()) {
417422
m_top_block = new SerializationBufferBlock();
423+
424+
if (m_wants_compress) {
425+
markForCompression(m_blocks.back());
426+
}
427+
418428
m_blocks.push_back(
419429
std::shared_ptr<SerializationBufferBlock>(
420430
m_top_block
421431
)
422432
);
423433
}
424-
425-
m_top_block->write(i);
426434
}
427435

428436
void startSerializing(Type* nativeType) {
@@ -456,6 +464,10 @@ class SerializationBuffer {
456464

457465
bool m_wants_compress;
458466

467+
bool m_compress_using_threads;
468+
469+
bool m_is_consolidated;
470+
459471
size_t m_size;
460472

461473
// the
@@ -472,6 +484,18 @@ class SerializationBuffer {
472484
std::set<Type*> m_types_being_serialized;
473485

474486
std::unordered_map<MutuallyRecursiveTypeGroup*, int> m_group_counter;
487+
488+
static void compressionThread();
489+
static std::shared_ptr<SerializationBufferBlock> getNextCompressTask();
490+
491+
void markForCompression(std::shared_ptr<SerializationBufferBlock> block);
492+
void waitForCompression(std::shared_ptr<SerializationBufferBlock> block);
493+
494+
static std::mutex s_compress_thread_mutex;
495+
static std::condition_variable* s_has_work;
496+
static std::vector<std::thread*> s_compress_threads;
497+
static std::unordered_set<std::shared_ptr<SerializationBufferBlock> > s_waiting_compress_blocks;
498+
static std::unordered_set<std::shared_ptr<SerializationBufferBlock> > s_working_compress_blocks;
475499
};
476500

477501
class MarkTypeBeingSerialized {

typed_python/SerializationContext.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class SerializationContext {
3434
virtual void serializeNativeType(Type* o, SerializationBuffer& b, size_t fieldNumber) const = 0;
3535
virtual Type* deserializeNativeType(DeserializationBuffer& b, size_t wireType) const = 0;
3636

37+
virtual bool compressUsingThreads() const = 0;
3738
virtual bool serializePodListsInline() const = 0;
3839
virtual bool isCompressionEnabled() const = 0;
3940
virtual bool isLineInfoSuppressed() const = 0;

0 commit comments

Comments
 (0)