From a8611d4f63f2b9e2de6fd76758576ae7fd44de4e Mon Sep 17 00:00:00 2001 From: Thomas Wang Date: Wed, 12 Nov 2025 15:52:20 -0800 Subject: [PATCH] FragmentedPart (#1818) Summary: This sets us up for the next diff by creating `FragmentedPart` Currently our pickle is still not truly zero copy because the Pickler calls `Buffer::write()` which is copying bytes from `PyBytes` to `BytesMut` via `extend_from_slice()`. This is especially problematic for large messages (100KB+) as we are spending a lot of CPU cycles handling page faults. For a 1MB message pickling can take as long as 600us To avoid copies, we can just make `Buffer` backed by a `Vec` with each call to `Buffer::write()` pushing the PyBytes to the Vec. As a result of this, the PyBytes are physically fragmented despite being logically contiguous. To make this work, we will have a NewType with called `FragmentedPart` with a `::Fragmented` variant wrapping `Vec` and a `::Contiguous` variant wrapping `Part`. Similar to `Part`, `FragmentedPart` also just collects during serialization. When we receive the frame on the other end of the wire, we reconstruct it contiguously in the `FragmentedPart::Contiguous` variant so that we can easily consume it to create a single contiguous `bytes::Bytes` Differential Revision: D86696390 --- serde_multipart/src/de/bincode.rs | 25 +++- serde_multipart/src/lib.rs | 211 +++++++++++++++++++++++++---- serde_multipart/src/part.rs | 162 ++++++++++++++++++++++ serde_multipart/src/ser/bincode.rs | 14 +- 4 files changed, 378 insertions(+), 34 deletions(-) diff --git a/serde_multipart/src/de/bincode.rs b/serde_multipart/src/de/bincode.rs index 849f45b3b..df6330645 100644 --- a/serde_multipart/src/de/bincode.rs +++ b/serde_multipart/src/de/bincode.rs @@ -15,21 +15,32 @@ use bincode::ErrorKind; use bincode::Options; use serde::de::IntoDeserializer; +use crate::FragmentedPart; use crate::part::Part; /// Multipart deserializer for bincode. This passes through to the underlying bincode -/// deserializer, but dequeues serialized parts when they are needed by [`Part::deserialize`]. +/// deserializer, but dequeues serialized parts when they are needed by +/// [`Part::deserialize`] or [`FragmentedPart::deserialize`]. pub struct Deserializer { de: bincode::Deserializer, parts: VecDeque, + fragmented_parts: VecDeque, } impl Deserializer where O: Options, { - pub(crate) fn new(de: bincode::Deserializer, parts: VecDeque) -> Self { - Self { de, parts } + pub(crate) fn new( + de: bincode::Deserializer, + parts: VecDeque, + fragmented_parts: VecDeque, + ) -> Self { + Self { + de, + parts, + fragmented_parts, + } } pub(crate) fn deserialize_part(&mut self) -> Result { @@ -38,8 +49,14 @@ where }) } + pub(crate) fn deserialize_fragmented_part(&mut self) -> Result { + self.fragmented_parts.pop_front().ok_or_else(|| { + ErrorKind::Custom("fragmented part underrun while decoding".to_string()).into() + }) + } + pub(crate) fn end(self) -> Result<(), Error> { - if self.parts.is_empty() { + if self.parts.is_empty() && self.fragmented_parts.is_empty() { Ok(()) } else { Err(ErrorKind::Custom("multipart overrun while decoding".to_string()).into()) diff --git a/serde_multipart/src/lib.rs b/serde_multipart/src/lib.rs index eb6c6d1ed..56caa5594 100644 --- a/serde_multipart/src/lib.rs +++ b/serde_multipart/src/lib.rs @@ -8,9 +8,9 @@ //! Serde codec for multipart messages. //! -//! Using [`serialize`] / [`deserialize`], fields typed [`Part`] are extracted -//! from the main payload and appended to a list of `parts`. Each part is backed by -//! [`bytes::Bytes`] for cheap, zero-copy sharing. +//! Using [`serialize`] / [`deserialize`], fields typed [`Part`] or [`FragmentedPart`] +//! are extracted from the main payload and appended to lists of parts. Each part is +//! backed by [`bytes::Bytes`] for cheap, zero-copy sharing. //! //! On decode, the body and its parts are reassembled into the original value //! without copying. @@ -20,11 +20,11 @@ //! efficient network I/O without compacting data into a single buffer. //! //! Implementation note: this crate uses Rust's min_specialization feature to enable -//! the use of [`Part`]s with any Serde serializer or deserializer. This feature -//! is fairly restrictive, and thus the API offered by [`serialize`] / [`deserialize`] +//! the use of [`Part`]s and [`FragmentedPart`]s with any Serde serializer or deserializer. +//! This feature is fairly restrictive, and thus the API offered by [`serialize`] / [`deserialize`] //! is not customizable. If customization is needed, you need to add specialization -//! implementations for these codecs. See [`part::PartSerializer`] and [`part::PartDeserializer`] -//! for details. +//! implementations for these codecs. See [`part::PartSerializer`], [`part::PartDeserializer`], +//! [`FragmentedPartSerializer`], and [`FragmentedPartDeserializer`] for details. #![feature(min_specialization)] #![feature(assert_matches)] @@ -46,6 +46,7 @@ mod part; mod ser; use bytes::Bytes; use bytes::BytesMut; +pub use part::FragmentedPart; pub use part::Part; use serde::Deserialize; use serde::Serialize; @@ -57,6 +58,7 @@ use serde::Serialize; pub struct Message { body: Part, parts: Vec, + fragmented_parts: Vec, is_illegal: bool, } @@ -66,6 +68,7 @@ impl Message { Self { body, parts, + fragmented_parts: vec![], is_illegal: false, } } @@ -87,60 +90,147 @@ impl Message { /// Returns the total size (in bytes) of the message. pub fn len(&self) -> usize { - self.body.len() + self.parts.iter().map(|part| part.len()).sum::() + self.body.len() + + self.parts.iter().map(|part| part.len()).sum::() + + self + .fragmented_parts + .iter() + .map(|fp| fp.len()) + .sum::() } /// Returns whether the message is empty. It is always false, since the body /// is always defined. pub fn is_empty(&self) -> bool { - self.body.is_empty() && self.parts.iter().all(|part| part.is_empty()) + self.body.is_empty() + && self.parts.iter().all(|part| part.is_empty()) + && self.fragmented_parts.iter().all(|fp| fp.is_empty()) } /// Convert this message into its constituent components. - pub fn into_inner(self) -> (Part, Vec) { - (self.body, self.parts) + pub fn into_inner(self) -> (Part, Vec, Vec) { + (self.body, self.parts, self.fragmented_parts) } /// Returns the total size (in bytes) of the message when it is framed. pub fn frame_len(&self) -> usize { - 8 * (1 + self.num_parts()) + self.len() + if self.is_illegal { + // Illegal messages use a simplified frame format: u64::MAX marker + body + return 8 + self.body.len(); + } + + // Headers: body_len (8) + num_regular_parts (8) + num_fragmented (8) + let header_bytes = 3 * 8; + + let body_bytes = self.body.len(); + + let regular_parts_bytes = + self.parts.len() * 8 + self.parts.iter().map(|p| p.len()).sum::(); + + let fragmented_parts_bytes = self.fragmented_parts.len() * 8 + + self + .fragmented_parts + .iter() + .map(|fp| fp.len()) + .sum::(); + + header_bytes + body_bytes + regular_parts_bytes + fragmented_parts_bytes } /// Efficiently frames a message containing the body and all of its parts /// using a simple frame-length encoding: /// /// ```text - /// +--------------------+-------------------+--------------------+-------------------+ ... + - /// | body_len (u64 BE) | body bytes | part1_len (u64 BE) | part1 bytes | | - /// +--------------------+-------------------+--------------------+-------------------+ + - /// repeat - /// for - /// each part + /// ┌─────────────────────────┐ + /// │ body_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ body bytes │ + /// ├─────────────────────────┤ + /// │ num_parts (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part1_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part1 bytes │ + /// ├─────────────────────────┤ + /// │ part2_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part2 bytes │ + /// ├─────────────────────────┤ + /// │ ... │ + /// ├─────────────────────────┤ + /// │ num_fragmented (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag1_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag1 bytes │ + /// ├─────────────────────────┤ + /// │ frag2_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag2 bytes │ + /// ├─────────────────────────┤ + /// │ ... │ + /// └─────────────────────────┘ /// ``` pub fn framed(self) -> Frame { let is_illegal = self.is_illegal; - let (body, parts) = self.into_inner(); + let (body, parts, fragmented_parts) = self.into_inner(); if is_illegal { - assert!(parts.is_empty(), "illegal illegal message"); + assert!( + parts.is_empty() && fragmented_parts.is_empty(), + "illegal illegal message" + ); return Frame::from_buffers(vec![ Bytes::from_owner(u64::MAX.to_be_bytes()), body.into_inner(), ]); } - let mut buffers = Vec::with_capacity(2 + 2 * parts.len()); + let has_fragmented = !fragmented_parts.is_empty(); + + let fragmented_total_parts: usize = + fragmented_parts.iter().map(|fp| fp.as_slice().len()).sum(); + let mut buffers = Vec::with_capacity( + 3 + // body_len + body + num_regular_parts + 2 * parts.len() + // Regular parts (len + data each) + if has_fragmented { 1 + fragmented_parts.len() + fragmented_total_parts } else { 0 }, + ); let body = body.into_inner(); buffers.push(Bytes::from_owner(body.len().to_be_bytes())); buffers.push(body); + // Number of regular parts + buffers.push(Bytes::from_owner(parts.len().to_be_bytes())); + for part in parts { let part = part.into_inner(); + // Length of this part buffers.push(Bytes::from_owner(part.len().to_be_bytes())); + buffers.push(part); } + if has_fragmented { + // Number of FragmentedParts + buffers.push(Bytes::from_owner(fragmented_parts.len().to_be_bytes())); + + for frag_part in fragmented_parts { + let parts = frag_part.into_parts(); + // Length of all parts/fragments + buffers.push(Bytes::from_owner( + (parts.iter().map(|p| p.len()).sum::() as u64).to_be_bytes(), + )); + + for part in parts { + buffers.push(part.into_inner()); + } + } + } else { + // Write 0 for num_fragmented if there are none + buffers.push(Bytes::from_owner(0u64.to_be_bytes())); + } + Frame::from_buffers(buffers) } @@ -154,18 +244,44 @@ impl Message { return Ok(Self { body: buf.into(), parts: vec![], + fragmented_parts: vec![], is_illegal: true, }); } let body = buf.split_to(body_len as usize); - let mut parts = Vec::new(); - while !buf.is_empty() { + + // Read number of regular parts + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + let num_regular_parts = buf.get_u64() as usize; + + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + + let mut parts = Vec::with_capacity(num_regular_parts); + for _ in 0..num_regular_parts { parts.push(Self::split_part(&mut buf)?.into()); } + + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + let num_fragmented = buf.get_u64() as usize; + + let mut fragmented_parts = Vec::with_capacity(num_fragmented); + for _ in 0..num_fragmented { + fragmented_parts.push(FragmentedPart::Contiguous( + Self::split_part(&mut buf)?.into(), + )); + } + Ok(Self { body: body.into(), parts, + fragmented_parts, is_illegal: false, }) } @@ -322,9 +438,11 @@ pub fn serialize_bincode( let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options())); value.serialize(&mut serializer)?; + let (parts, fragmented_parts) = serializer.into_parts(); Ok(Message { body: Part(buffer.into_inner().freeze()), - parts: serializer.into_parts(), + parts, + fragmented_parts, is_illegal: false, }) } @@ -336,17 +454,18 @@ where T: serde::de::DeserializeOwned, { if message.is_illegal { - let (body, parts) = message.into_inner(); - if !parts.is_empty() { + let (body, parts, fragmented_parts) = message.into_inner(); + if !parts.is_empty() || !fragmented_parts.is_empty() { return Err(bincode::ErrorKind::Custom("illegal illegal message".to_string()).into()); } return bincode::deserialize_from(body.into_inner().reader()); } - let (body, parts) = message.into_inner(); + let (body, parts, fragmented_parts) = message.into_inner(); let mut deserializer = part::BincodeDeserializer::new( bincode::Deserializer::with_reader(body.into_inner().reader(), options()), parts.into(), + fragmented_parts.into(), ); let value = T::deserialize(&mut deserializer)?; // Check that all parts were consumed: @@ -366,6 +485,7 @@ pub fn serialize_illegal_bincode( Ok(Message { body: Part::from(bincode::serialize(value)?), parts: vec![], + fragmented_parts: vec![], is_illegal: true, }) } @@ -507,6 +627,7 @@ mod tests { let message = Message { body: Part::from("hello"), parts: vec![Part::from("world")], + fragmented_parts: vec![], is_illegal: false, }; let err = deserialize_bincode::(message).unwrap_err(); @@ -565,6 +686,7 @@ mod tests { Part::from("xyz"), Part::from("xyzd"), ], + fragmented_parts: vec![], is_illegal: false, }; @@ -573,6 +695,41 @@ mod tests { assert_eq!(Message::from_framed(framed).unwrap(), message); } + #[test] + fn test_fragmented_part_roundtrip() { + let fragments = vec![ + Part::from("Hello"), + Part::from(" "), + Part::from("World"), + Part::from("!"), + ]; + let expected_data = b"Hello World!"; + + let fragmented = FragmentedPart::Fragmented(fragments); + assert!(matches!(fragmented, FragmentedPart::Fragmented(_))); + + #[derive(Serialize, Deserialize, Debug)] + struct TestStruct { + data: FragmentedPart, + } + + let test_struct = TestStruct { data: fragmented }; + + let message = serialize_bincode(&test_struct).unwrap(); + + let mut framed = message.framed(); + let framed_bytes = framed.copy_to_bytes(framed.remaining()); + + let unframed_message = Message::from_framed(framed_bytes).unwrap(); + + let deserialized: TestStruct = deserialize_bincode(unframed_message).unwrap(); + + assert!(matches!(deserialized.data, FragmentedPart::Contiguous(_))); + + let contiguous_bytes = deserialized.data.into_bytes(); + assert_eq!(&*contiguous_bytes, expected_data); + } + #[test] fn test_socket_addr() { let socket_addr_v6: SocketAddrV6 = diff --git a/serde_multipart/src/part.rs b/serde_multipart/src/part.rs index a90b3358c..8678e1df9 100644 --- a/serde_multipart/src/part.rs +++ b/serde_multipart/src/part.rs @@ -9,6 +9,7 @@ use std::ops::Deref; use bytes::Bytes; +use bytes::BytesMut; use bytes::buf::Reader as BufReader; use bytes::buf::Writer as BufWriter; use serde::Deserialize; @@ -124,3 +125,164 @@ impl<'de, 'a> PartDeserializer<'de, &'a mut BincodeDeserializer> for Part { deserializer.deserialize_part() } } + +/// A logically contiguous part that may be physically fragmented or contiguous. +/// +/// During serialization, parts are extracted separately (allowing zero-copy from construction). +/// During deserialization, data arrives as a single contiguous `Part`. +/// +/// Use this when: +/// - Construction creates multiple Parts (e.g., multiple pickle writes to a Buffer) +/// - Consumption needs contiguous bytes (e.g., unpickling requires contiguous buffer) +/// - Network read already gives contiguous bytes (no need to split and re-concat) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FragmentedPart { + /// Multiple fragments that need to be concatenated when accessed + Fragmented(Vec), + /// Already contiguous data (typically from deserialization) + Contiguous(Part), +} + +impl Default for FragmentedPart { + fn default() -> Self { + Self::Contiguous(Part::default()) + } +} + +impl FragmentedPart { + pub fn new(parts: Vec) -> Self { + if parts.len() == 1 { + Self::Contiguous(parts.into_iter().next().unwrap()) + } else { + Self::Fragmented(parts) + } + } + + pub fn into_parts(self) -> Vec { + match self { + Self::Fragmented(parts) => parts, + Self::Contiguous(part) => vec![part], + } + } + + /// Convert into bytes, concatenating fragments if necessary. + pub fn into_bytes(self) -> Bytes { + match self { + Self::Contiguous(part) => part.into_inner(), + Self::Fragmented(parts) => { + let total_len: usize = parts.iter().map(|p| p.len()).sum(); + let mut result = BytesMut::with_capacity(total_len); + for part in parts { + result.extend_from_slice(&part.to_bytes()); + } + result.freeze() + } + } + } + + /// Get bytes as a reference, concatenating fragments if necessary. + pub fn as_bytes(&self) -> Bytes { + match self { + Self::Contiguous(part) => part.to_bytes(), + Self::Fragmented(parts) => { + let total_len: usize = parts.iter().map(|p| p.len()).sum(); + let mut result = BytesMut::with_capacity(total_len); + for part in parts { + result.extend_from_slice(&part.to_bytes()); + } + result.freeze() + } + } + } + + pub fn as_slice(&self) -> &[Part] { + match self { + Self::Fragmented(parts) => parts.as_slice(), + Self::Contiguous(part) => std::slice::from_ref(part), + } + } + + /// Returns the total length in bytes of the fragmented part. + /// For contiguous parts, this is just the part length. + /// For fragmented parts, this is the sum of all fragment lengths. + pub fn len(&self) -> usize { + match self { + Self::Contiguous(part) => part.len(), + Self::Fragmented(parts) => parts.iter().map(|p| p.len()).sum(), + } + } + + /// Returns whether the fragmented part is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +/// Serialization trait for FragmentedPart (similar to PartSerializer) +trait FragmentedPartSerializer { + fn serialize(parts: &FragmentedPart, s: S) -> Result; +} + +/// Default: serialize as Vec +impl FragmentedPartSerializer for FragmentedPart { + default fn serialize(part: &FragmentedPart, s: S) -> Result { + match part { + FragmentedPart::Fragmented(parts) => parts.serialize(s), + FragmentedPart::Contiguous(part) => vec![part.clone()].serialize(s), + } + } +} + +/// Specialized for our BincodeSerializer +impl<'a> FragmentedPartSerializer<&'a mut BincodeSerializer> for FragmentedPart { + fn serialize( + parts: &FragmentedPart, + s: &'a mut BincodeSerializer, + ) -> Result<(), bincode::Error> { + // Tell the serializer to extract this as a fragmented part + s.serialize_fragmented_part(parts); + // Serialize as empty Vec in the body (parts are extracted) + Vec::::new().serialize(s) + } +} + +impl Serialize for FragmentedPart { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + >::serialize(self, serializer) + } +} + +/// Deserialization trait for FragmentedPart +trait FragmentedPartDeserializer<'de, D: serde::Deserializer<'de>>: Sized { + fn deserialize(d: D) -> Result; +} + +/// Default: deserialize as Vec +impl<'de, D: serde::Deserializer<'de>> FragmentedPartDeserializer<'de, D> for FragmentedPart { + default fn deserialize(deserializer: D) -> Result { + let parts = Vec::::deserialize(deserializer)?; + Ok(Self::new(parts)) + } +} + +/// Specialized for our BincodeDeserializer +impl<'de, 'a> FragmentedPartDeserializer<'de, &'a mut BincodeDeserializer> for FragmentedPart { + fn deserialize(deserializer: &'a mut BincodeDeserializer) -> Result { + // Read the Vec (should be empty from serialization) + let _empty: Vec = Vec::deserialize(&mut *deserializer)?; + // Pull the actual fragmented part from the deserializer + deserializer.deserialize_fragmented_part() + } +} + +impl<'de> Deserialize<'de> for FragmentedPart { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + >::deserialize(deserializer) + } +} diff --git a/serde_multipart/src/ser/bincode.rs b/serde_multipart/src/ser/bincode.rs index 36b54b753..dbae7e709 100644 --- a/serde_multipart/src/ser/bincode.rs +++ b/serde_multipart/src/ser/bincode.rs @@ -13,13 +13,15 @@ use ::bincode::Options; use serde::Serialize; use serde::ser; +use crate::FragmentedPart; use crate::Part; /// Multipart serializer for bincode. This passes through serialization to bincode, -/// but also records the parts encoded by [`Part::serialize`]. +/// but also records the parts encoded by [`Part::serialize`] and [`FragmentedPart::serialize`]. pub struct Serializer { ser: ::bincode::Serializer, parts: Vec, + fragmented_parts: Vec, } impl Serializer { @@ -27,6 +29,7 @@ impl Serializer { Self { ser, parts: Vec::new(), + fragmented_parts: Vec::new(), } } @@ -35,8 +38,13 @@ impl Serializer { self.parts.push(part.clone()); } - pub(crate) fn into_parts(self) -> Vec { - self.parts + /// Serialize a FragmentedPart by appending it to the fragmented_parts list. + pub(crate) fn serialize_fragmented_part(&mut self, parts: &FragmentedPart) { + self.fragmented_parts.push(parts.clone()); + } + + pub(crate) fn into_parts(self) -> (Vec, Vec) { + (self.parts, self.fragmented_parts) } }