Skip to content

Commit 533f146

Browse files
committed
Encode POD list data in a different fieldnum so that any SerializationContext can decode it.
This allows any SerializationContext to read data that was encoded with the withSerializePodListsInline flag either on or off.
1 parent f7834a1 commit 533f146

File tree

2 files changed

+65
-34
lines changed

2 files changed

+65
-34
lines changed

typed_python/TupleOrListOfType.hpp

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -435,24 +435,27 @@ class TupleOrListOfType : public Type {
435435
buffer.writeBeginCompound(fieldNumber);
436436
}
437437

438-
buffer.writeUnsignedVarintObject(0, ct);
439438

440439
if (ct && m_element_type->isPOD() && buffer.getContext().serializePodListsInline()) {
441440
if (m_element_type->getTypeCategory() == TypeCategory::catInt64) {
441+
buffer.writeUnsignedVarintObject(1, ct);
442+
442443
serializeIntList(
443444
(int64_t*)this->eltPtr(self, 0),
444445
ct,
445446
buffer
446447
);
447448
} else {
448-
buffer.writeBeginBytes(0, m_element_type->bytecount() * ct);
449+
buffer.writeBeginBytes(2, m_element_type->bytecount() * ct);
449450

450451
buffer.write_bytes(
451452
this->eltPtr(self, 0),
452453
m_element_type->bytecount() * ct
453454
);
454455
}
455456
} else {
457+
buffer.writeUnsignedVarintObject(0, ct);
458+
456459
m_element_type->check([&](auto& concrete_type) {
457460
concrete_type.serializeMulti(
458461
this->eltPtr(self, 0),
@@ -496,47 +499,23 @@ class TupleOrListOfType : public Type {
496499
}
497500
}
498501

499-
size_t ct = buffer.readUnsignedVarintObject();
502+
auto fieldnumAndWireType = buffer.readFieldNumberAndWireType();
503+
size_t fieldnum = fieldnumAndWireType.first;
504+
size_t ct = buffer.readUnsignedVarint();
500505

501506
if (ct == 0) {
502507
constructor(self);
503508

509+
if (fieldnum != 0) {
510+
throw std::runtime_error("Corrupt field num - empty list/tuple count should be 0");
511+
}
512+
504513
if (isListOf()) {
505514
(*(layout**)self)->refcount++;
506515
buffer.addCachedPointer(id, *((layout**)self), this);
507516
}
508517
} else {
509-
if (m_element_type->isPOD() && buffer.getContext().serializePodListsInline()) {
510-
constructor(self, ct, [&](instance_ptr tgt, int k) {});
511-
512-
if (isListOf()) {
513-
(*(layout**)self)->refcount++;
514-
buffer.addCachedPointer(id, *((layout**)self), this);
515-
}
516-
if (m_element_type->getTypeCategory() == TypeCategory::catInt64) {
517-
deserializeIntList(
518-
(int64_t*)this->eltPtr(self, 0),
519-
ct,
520-
buffer
521-
);
522-
} else {
523-
auto fnAndWt = buffer.readFieldNumberAndWireType();
524-
size_t bytecount = buffer.readUnsignedVarint();
525-
526-
if (fnAndWt.second != WireType::BYTES) {
527-
throw std::runtime_error("Corrupt data (expected BYTES)");
528-
}
529-
530-
if (bytecount != m_element_type->bytecount() * ct) {
531-
throw std::runtime_error("Corrupt data (bytecount doesn't match)");
532-
}
533-
534-
buffer.read_bytes(
535-
this->eltPtr(self, 0),
536-
m_element_type->bytecount() * ct
537-
);
538-
}
539-
} else {
518+
if (fieldnum == 0) {
540519
constructor(self, ct, [&](instance_ptr tgt, int k) {
541520
if (k == 0 && isListOf()) {
542521
buffer.addCachedPointer(id, *((layout**)self), this);
@@ -553,6 +532,48 @@ class TupleOrListOfType : public Type {
553532

554533
m_element_type->deserialize(tgt, buffer, fieldAndWire.second);
555534
});
535+
} else
536+
if (fieldnum == 1) {
537+
if (m_element_type->getTypeCategory() != TypeCategory::catInt64) {
538+
throw std::runtime_error(
539+
"Compressed intArray data data makes no sense for " + m_element_type->name()
540+
);
541+
}
542+
constructor(self, ct, [&](instance_ptr tgt, int k) {});
543+
544+
if (isListOf()) {
545+
(*(layout**)self)->refcount++;
546+
buffer.addCachedPointer(id, *((layout**)self), this);
547+
}
548+
549+
deserializeIntList(
550+
(int64_t*)this->eltPtr(self, 0),
551+
ct,
552+
buffer
553+
);
554+
} else
555+
if (fieldnum == 2) {
556+
if (!m_element_type->isPOD()) {
557+
throw std::runtime_error(
558+
"Compressed POD data makes no sense for " + m_element_type->name()
559+
);
560+
}
561+
562+
size_t eltCount = ct / m_element_type->bytecount();
563+
if (eltCount * m_element_type->bytecount() != ct) {
564+
throw std::runtime_error("Invalid inline POD data - not a proper multiple");
565+
}
566+
567+
constructor(self, eltCount, [&](instance_ptr tgt, int k) {});
568+
569+
if (isListOf()) {
570+
(*(layout**)self)->refcount++;
571+
buffer.addCachedPointer(id, *((layout**)self), this);
572+
}
573+
574+
buffer.read_bytes(this->eltPtr(self, 0), ct);
575+
} else {
576+
throw std::runtime_error("Corrupt fieldnum for tuple/listof body");
556577
}
557578
}
558579

typed_python/types_serialization_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,16 @@ def test_serialize_lists_compression_and_threads(self):
401401
print("nocompress is ", nocompressTime, int(size3 / 1024 / 1024))
402402
print("speedup is ", normalTime / threadTime)
403403

404+
def test_can_deserialize_pod_lists_with_any_context(self):
405+
someFloats = ListOf(float)(range(1000))
406+
someInts = ListOf(int)(range(1000))
407+
408+
s1 = SerializationContext().withSerializePodListsInline()
409+
s2 = SerializationContext()
410+
411+
assert s2.deserialize(s1.serialize(someFloats)) == someFloats
412+
assert s2.deserialize(s1.serialize(someInts)) == someInts
413+
404414
def test_serialize_core_python_objects(self):
405415
self.check_idempotence(0)
406416
self.check_idempotence(10)

0 commit comments

Comments
 (0)