diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index 8d7ed8738..8816479bf 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -44,7 +44,7 @@ use hyperactor_mesh::v1::Name; use hyperactor_mesh_macros::sel; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; -use monarch_hyperactor::buffers::FrozenBuffer; +use monarch_hyperactor::buffers::Buffer; use monarch_hyperactor::context::PyInstance; use monarch_hyperactor::instance_dispatch; use monarch_hyperactor::local_state_broker::LocalStateBrokerActor; @@ -571,10 +571,10 @@ impl History { let exe = remote_exception .call1((exception.backtrace, traceback, rank)) .unwrap(); - let data: FrozenBuffer = pickle.call1((exe,)).unwrap().extract().unwrap(); - PythonMessage::new_from_buf( + let mut data: Buffer = pickle.call1((exe,)).unwrap().extract().unwrap(); + PythonMessage::new_from_fragmented( PythonMessageKind::Exception { rank: Some(rank) }, - data.inner, + data.into_fragmented_part(), ) })); diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 3e3b1c8b4..76abe77b0 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -47,11 +47,13 @@ use pyo3::types::PyType; use serde::Deserialize; use serde::Serialize; use serde_bytes::ByteBuf; +use serde_multipart::FragmentedPart; use serde_multipart::Part; use tokio::sync::Mutex; use tokio::sync::oneshot; use tracing::Instrument; +use crate::buffers::Buffer; use crate::buffers::FrozenBuffer; use crate::config::SHARED_ASYNCIO_RUNTIME; use crate::context::PyInstance; @@ -265,7 +267,7 @@ fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, Py #[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)] pub struct PythonMessage { pub kind: PythonMessageKind, - pub message: Part, + pub message: FragmentedPart, } struct ResolvedCallMethod { @@ -281,7 
+283,14 @@ impl PythonMessage { pub fn new_from_buf(kind: PythonMessageKind, message: impl Into) -> Self { Self { kind, - message: message.into(), + message: FragmentedPart::Contiguous(message.into()), + } + } + + pub fn new_from_fragmented(kind: PythonMessageKind, fragmented_part: FragmentedPart) -> Self { + Self { + kind, + message: fragmented_part, } } @@ -336,7 +345,7 @@ impl PythonMessage { Ok(ResolvedCallMethod { method: name, bytes: FrozenBuffer { - inner: self.message.into_inner(), + inner: self.message.into_bytes(), }, local_state, response_port, @@ -375,7 +384,7 @@ impl PythonMessage { Ok(ResolvedCallMethod { method: name, bytes: FrozenBuffer { - inner: self.message.into_inner(), + inner: self.message.into_bytes(), }, local_state, response_port, @@ -394,7 +403,7 @@ impl std::fmt::Debug for PythonMessage { .field("kind", &self.kind) .field( "message", - &hyperactor::data::HexFmt(&(*self.message)[..]).to_string(), + &hyperactor::data::HexFmt(&(*self.message.as_bytes())[..]).to_string(), ) .finish() } @@ -423,9 +432,11 @@ impl PythonMessage { #[new] #[pyo3(signature = (kind, message))] pub fn new<'py>(kind: PythonMessageKind, message: Bound<'py, PyAny>) -> PyResult { - if let Ok(buff) = message.extract::>() { - let frozen = buff.borrow_mut(); - return Ok(PythonMessage::new_from_buf(kind, frozen.inner.clone())); + if let Ok(mut buff) = message.extract::>() { + return Ok(PythonMessage::new_from_fragmented( + kind, + buff.into_fragmented_part(), + )); } else if let Ok(buff) = message.extract::>() { return Ok(PythonMessage::new_from_buf( kind, @@ -446,7 +457,7 @@ impl PythonMessage { #[getter] fn message(&self) -> FrozenBuffer { FrozenBuffer { - inner: self.message.clone().into_inner(), + inner: self.message.as_bytes(), } } } @@ -1001,7 +1012,7 @@ mod tests { }, response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())), }, - message: Part::from(vec![1, 2, 3]), + message: FragmentedPart::Contiguous(Part::from(vec![1, 2, 3])), }; { let mut erased = 
ErasedUnbound::try_from_message(message.clone()).unwrap(); diff --git a/monarch_hyperactor/src/buffers.rs b/monarch_hyperactor/src/buffers.rs index d34b22561..bb605c1fb 100644 --- a/monarch_hyperactor/src/buffers.rs +++ b/monarch_hyperactor/src/buffers.rs @@ -12,7 +12,6 @@ use std::ffi::c_int; use std::ffi::c_void; use bytes::Buf; -use bytes::BytesMut; use hyperactor::Named; use pyo3::buffer::PyBuffer; use pyo3::prelude::*; @@ -20,78 +19,85 @@ use pyo3::types::PyBytes; use pyo3::types::PyBytesMethods; use serde::Deserialize; use serde::Serialize; +use serde_multipart::FragmentedPart; +use serde_multipart::Part; + +/// Wrapper that keeps Py alive while allowing zero-copy access to its memory +struct PyBytesWrapper { + _py_bytes: Py, + ptr: *const u8, + len: usize, +} + +impl PyBytesWrapper { + fn new(py_bytes: Py) -> Self { + let (ptr, len) = Python::with_gil(|py| { + let bytes_ref = py_bytes.as_bytes(py); + (bytes_ref.as_ptr(), bytes_ref.len()) + }); + Self { + _py_bytes: py_bytes, + ptr, + len, + } + } +} + +impl AsRef<[u8]> for PyBytesWrapper { + fn as_ref(&self) -> &[u8] { + // SAFETY: ptr is valid as long as py_bytes is alive (kept alive by Py) + // Python won't free the memory until the Py refcount reaches 0 + unsafe { std::slice::from_raw_parts(self.ptr, self.len) } + } +} + +// SAFETY: Py is Send/Sync for immutable bytes +unsafe impl Send for PyBytesWrapper {} +// SAFETY: Py is Send/Sync for immutable bytes +unsafe impl Sync for PyBytesWrapper {} /// A mutable buffer for reading and writing bytes data. /// -/// The `Buffer` struct provides an interface for accumulating byte data that can be written to -/// and then frozen into an immutable `FrozenBuffer` for reading. It uses the `bytes::BytesMut` -/// internally for efficient memory management. +/// The `Buffer` struct provides an interface for accumulating byte data from Python `bytes` objects +/// that can be converted into a `FragmentedPart` for zero-copy multipart message serialization. 
+/// It accumulates references to Python bytes objects without copying. /// /// # Examples /// /// ```python /// from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer /// -/// # Create a new buffer with default capacity (4096 bytes) +/// # Create a new buffer /// buffer = Buffer() /// /// # Write some data /// data = b"Hello, World!" /// bytes_written = buffer.write(data) /// -/// # Check length -/// print(len(buffer)) # 13 -/// -/// # Freeze for reading -/// frozen = buffer.freeze() -/// content = frozen.read() +/// # Use in multipart serialization +/// # The buffer accumulates multiple writes as separate fragments /// ``` #[pyclass(subclass, module = "monarch._rust_bindings.monarch_hyperactor.buffers")] -#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)] +#[derive(Clone, Default)] pub struct Buffer { - pub(crate) inner: bytes::BytesMut, -} - -impl Buffer { - /// Consumes the Buffer and returns the underlying BytesMut. - /// This allows zero-copy access to the raw buffer data. - pub fn into_inner(self) -> bytes::BytesMut { - self.inner - } -} - -impl From for Buffer -where - T: Into, -{ - fn from(value: T) -> Self { - Self { - inner: value.into(), - } - } + inner: Vec>, } #[pymethods] impl Buffer { /// Creates a new empty buffer with specified initial capacity. /// - /// # Arguments - /// * `size` - Initial capacity in bytes (default: 4096) /// /// # Returns /// A new empty `Buffer` instance with the specified capacity. #[new] - #[pyo3(signature=(size=4096))] - fn new(size: usize) -> Self { - Self { - inner: bytes::BytesMut::with_capacity(size), - } + fn new() -> Self { + Self { inner: Vec::new() } } /// Writes bytes data to the buffer. /// - /// Appends the provided bytes to the end of the buffer, extending its capacity - /// if necessary. + /// This keeps a reference to the Python bytes object without copying. 
/// /// # Arguments /// * `buff` - The bytes object to write to the buffer @@ -100,26 +106,49 @@ impl Buffer { /// The number of bytes written (always equal to the length of input bytes) fn write<'py>(&mut self, buff: &Bound<'py, PyBytes>) -> usize { let bytes_written = buff.as_bytes().len(); - self.inner.extend_from_slice(buff.as_bytes()); + self.inner.push(buff.clone().unbind()); bytes_written } - /// Freezes this buffer into an immutable `FrozenBuffer`. + /// Freezes the buffer, converting it into an immutable `FrozenBuffer` for reading. + /// + /// This consumes all accumulated PyBytes and converts them into a contiguous bytes buffer. + /// After freezing, the original buffer is cleared. /// - /// This operation consumes the mutable buffer's contents, transferring ownership - /// to a new `FrozenBuffer` that can only be read from. The original buffer - /// becomes empty after this operation. + /// This operation should be avoided in hot paths as it creates a copy in order to concatenate + /// bytes that are fragmented in memory into a single series of contiguous bytes. /// /// # Returns - /// A new `FrozenBuffer` containing all the data that was in this buffer + /// A new `FrozenBuffer` containing all the bytes that were written to this buffer fn freeze(&mut self) -> FrozenBuffer { - let buff = std::mem::take(&mut self.inner); + let fragmented_part = self.into_fragmented_part(); FrozenBuffer { - inner: buff.freeze(), + inner: fragmented_part.into_bytes(), } } } +impl Buffer { + /// Converts accumulated `PyBytes` objects to [`FragmentedPart`] for zero-copy multipart messages. + /// + /// Returns a `FragmentedPart::Fragmented` variant since the buffer accumulates multiple + /// separate PyBytes objects that remain physically fragmented. 
+ pub fn into_fragmented_part(&mut self) -> FragmentedPart { + let inner = std::mem::take(&mut self.inner); + + FragmentedPart::Fragmented( + inner + .into_iter() + .map(|py_bytes| { + let wrapper = PyBytesWrapper::new(py_bytes); + let bytes = bytes::Bytes::from_owner(wrapper); + Part::from(bytes) + }) + .collect(), + ) + } +} + /// An immutable buffer for reading bytes data. /// /// The `FrozenBuffer` struct provides a read-only interface to byte data. Once created, diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 02e468032..e9681a281 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -37,7 +37,7 @@ use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; -use monarch_hyperactor::buffers::FrozenBuffer; +use monarch_hyperactor::buffers::Buffer; use monarch_hyperactor::local_state_broker::BrokerId; use monarch_hyperactor::local_state_broker::LocalState; use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage; @@ -103,16 +103,16 @@ fn pickle_python_result( .unwrap() .getattr("_pickle") .unwrap(); - let data: FrozenBuffer = pickle + let mut data: Buffer = pickle .call1((result,)) .map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))? 
.extract() .unwrap(); - Ok(PythonMessage::new_from_buf( + Ok(PythonMessage::new_from_fragmented( PythonMessageKind::Result { rank: Some(worker_rank), }, - data.inner, + data.into_fragmented_part(), )) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index 8ce6ea25b..796a1ea15 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -23,7 +23,7 @@ from typing import ( Union, ) -from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer, FrozenBuffer from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -204,7 +204,7 @@ class PythonMessage: def __init__( self, kind: PythonMessageKind, - message: Union[FrozenBuffer, bytes], + message: Union[Buffer, bytes], ) -> None: ... @property def message(self) -> FrozenBuffer: diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/buffers.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/buffers.pyi index 08b16e373..4fbacf306 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/buffers.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/buffers.pyi @@ -130,36 +130,29 @@ class Buffer: """ A mutable buffer for reading and writing bytes data. - The `Buffer` struct provides an interface for accumulating byte data that can be written to - and then frozen into an immutable `FrozenBuffer` for reading. It uses the `bytes::BytesMut` - internally for efficient memory management. + The `Buffer` struct provides an interface for accumulating byte data from Python `bytes` objects + that can be converted into a `FragmentedPart` for zero-copy multipart message serialization. + It accumulates references to Python bytes objects without copying. 
Examples: ```python from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer - # Create a new buffer with default capacity (4096 bytes) + # Create a new buffer buffer = Buffer() # Write some data data = b"Hello, World!" bytes_written = buffer.write(data) - # Check length - print(len(buffer)) # 13 - - # Freeze for reading - frozen = buffer.freeze() - content = frozen.read() + # Use in multipart serialization + # The buffer accumulates multiple writes as separate fragments ``` """ - def __init__(self, size: int = 4096) -> None: + def __init__(self) -> None: """ - Create a new empty buffer with specified initial capacity. - - Arguments: - - `size`: Initial capacity in bytes (default: 4096) + Create a new empty buffer. """ ... @@ -167,8 +160,7 @@ class Buffer: """ Write bytes data to the buffer. - Appends the provided bytes to the end of the buffer, extending its capacity - if necessary. + This keeps a reference to the Python bytes object without copying. Arguments: - `buff`: The bytes object to write to the buffer @@ -178,15 +170,6 @@ class Buffer: """ ... - def __len__(self) -> int: - """ - Return the number of bytes remaining in the buffer. - - Returns: - The number of bytes that can be read from the buffer - """ - ... - def freeze(self) -> FrozenBuffer: """ Freeze this buffer into an immutable `FrozenBuffer`. @@ -195,6 +178,9 @@ class Buffer: to a new `FrozenBuffer` that can only be read from. The original buffer becomes empty after this operation. 
+ This operation should generally be avoided in hot paths as it creates copies in order to concatenate + bytes that are potentially fragmented in memory into a single contiguous series of bytes + Returns: A new `FrozenBuffer` containing all the data that was in this buffer """ diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 43aa91758..0241b0f48 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -52,7 +52,7 @@ PythonMessageKind, ) from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh -from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer, FrozenBuffer from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport from monarch._rust_bindings.monarch_hyperactor.config import configure from monarch._rust_bindings.monarch_hyperactor.context import Instance as HyInstance @@ -1161,7 +1161,7 @@ def _is_ref_or_mailbox(x: object) -> bool: return hasattr(x, "__monarch_ref__") or isinstance(x, Mailbox) -def _pickle(obj: object) -> bytes | FrozenBuffer: +def _pickle(obj: object) -> Buffer: _, buff = flatten(obj, _is_mailbox) return buff diff --git a/python/monarch/_src/actor/pickle.py b/python/monarch/_src/actor/pickle.py index e430522e2..65a61b2c2 100644 --- a/python/monarch/_src/actor/pickle.py +++ b/python/monarch/_src/actor/pickle.py @@ -135,12 +135,12 @@ def persistent_load(self, pid: Any) -> Any: return self._values[pid] -def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], FrozenBuffer]: +def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], Buffer]: buffer = Buffer() pickler = _Pickler(filter, buffer) pickler.dump(obj) - return pickler._saved, buffer.freeze() + return pickler._saved, buffer def unflatten(data: FrozenBuffer | bytes, values: Iterable[Any]) -> Any: diff --git 
a/python/monarch/_src/actor/tensor_engine_shim.py b/python/monarch/_src/actor/tensor_engine_shim.py index 6e41bf39c..51005fc35 100644 --- a/python/monarch/_src/actor/tensor_engine_shim.py +++ b/python/monarch/_src/actor/tensor_engine_shim.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from monarch._src.actor.actor_mesh import ActorEndpoint, Port, Selection -from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer P = ParamSpec("P") F = TypeVar("F", bound=Callable[..., Any]) @@ -70,7 +70,7 @@ def wrap(*args: Any, **kwargs: Any) -> Any: @shim(module="monarch.mesh_controller") def actor_send( endpoint: "ActorEndpoint[..., ...]", - args_kwargs_tuple: bytes, + args_kwargs_tuple: Buffer, refs: "Sequence[Any]", port: "Optional[Port[Any]]", selection: "Selection", @@ -80,7 +80,7 @@ def actor_send( @shim(module="monarch.mesh_controller") def actor_rref( endpoint: Any, - args_kwargs_tuple: FrozenBuffer, + args_kwargs_tuple: Buffer, refs: Sequence[Any], ) -> Any: ... 
diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 7c06af957..e3fe61e02 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -38,7 +38,7 @@ PythonMessageKind, UnflattenArg, ) -from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer +from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, @@ -308,7 +308,7 @@ def _cast_call_method_indirect( selection: str, client: MeshClient, seq: Seq, - args_kwargs_tuple: FrozenBuffer, + args_kwargs_tuple: Buffer, refs: Sequence[Any], ) -> Tuple[str, int]: unflatten_args = [ @@ -328,7 +328,7 @@ def _cast_call_method_indirect( def actor_send( endpoint: ActorEndpoint, - args_kwargs_tuple: FrozenBuffer, + args_kwargs_tuple: Buffer, refs: Sequence[Any], port: Optional[Port[Any]], selection: str, @@ -370,7 +370,7 @@ def actor_send( def _actor_send( endpoint: ActorEndpoint, - args_kwargs_tuple: FrozenBuffer, + args_kwargs_tuple: Buffer, refs: Sequence[Any], port: Optional[Port[Any]], selection: str, @@ -404,7 +404,7 @@ def _actor_send( client._request_status() -def actor_rref(endpoint, args_kwargs_tuple: FrozenBuffer, refs: Sequence[Any]): +def actor_rref(endpoint, args_kwargs_tuple: Buffer, refs: Sequence[Any]): chosen_stream = stream._active fake_result, dtensors, mutates, mesh = dtensor_check( endpoint._propagate, diff --git a/python/tests/_monarch/test_hyperactor.py b/python/tests/_monarch/test_hyperactor.py index a126b1bf5..2abb52b8b 100644 --- a/python/tests/_monarch/test_hyperactor.py +++ b/python/tests/_monarch/test_hyperactor.py @@ -108,7 +108,6 @@ def test_buffer_read_write() -> None: def test_pickle_to_buffer() -> None: x = [bytes(100000)] - b = Buffer() args, b = flatten(x, lambda x: False) - y = unflatten(b, args) + y = 
unflatten(b.freeze(), args) assert x == y diff --git a/python/tests/test_host_mesh.py b/python/tests/test_host_mesh.py index 5bb32201b..baf026748 100644 --- a/python/tests/test_host_mesh.py +++ b/python/tests/test_host_mesh.py @@ -115,7 +115,7 @@ def test_pickle() -> None: Extent(["replicas", "hosts"], [2, 4]), ) _unused, pickled = flatten(host, lambda _: False) - unpickled = unflatten(pickled, _unused) + unpickled = unflatten(pickled.freeze(), _unused) assert isinstance(unpickled, HostMesh) assert host.extent.labels == ["replicas", "hosts"] assert host.extent.sizes == [2, 4] diff --git a/serde_multipart/src/de/bincode.rs b/serde_multipart/src/de/bincode.rs index 849f45b3b..df6330645 100644 --- a/serde_multipart/src/de/bincode.rs +++ b/serde_multipart/src/de/bincode.rs @@ -15,21 +15,32 @@ use bincode::ErrorKind; use bincode::Options; use serde::de::IntoDeserializer; +use crate::FragmentedPart; use crate::part::Part; /// Multipart deserializer for bincode. This passes through to the underlying bincode -/// deserializer, but dequeues serialized parts when they are needed by [`Part::deserialize`]. +/// deserializer, but dequeues serialized parts when they are needed by +/// [`Part::deserialize`] or [`FragmentedPart::deserialize`]. 
pub struct Deserializer { de: bincode::Deserializer, parts: VecDeque, + fragmented_parts: VecDeque, } impl Deserializer where O: Options, { - pub(crate) fn new(de: bincode::Deserializer, parts: VecDeque) -> Self { - Self { de, parts } + pub(crate) fn new( + de: bincode::Deserializer, + parts: VecDeque, + fragmented_parts: VecDeque, + ) -> Self { + Self { + de, + parts, + fragmented_parts, + } } pub(crate) fn deserialize_part(&mut self) -> Result { @@ -38,8 +49,14 @@ where }) } + pub(crate) fn deserialize_fragmented_part(&mut self) -> Result { + self.fragmented_parts.pop_front().ok_or_else(|| { + ErrorKind::Custom("fragmented part underrun while decoding".to_string()).into() + }) + } + pub(crate) fn end(self) -> Result<(), Error> { - if self.parts.is_empty() { + if self.parts.is_empty() && self.fragmented_parts.is_empty() { Ok(()) } else { Err(ErrorKind::Custom("multipart overrun while decoding".to_string()).into()) diff --git a/serde_multipart/src/lib.rs b/serde_multipart/src/lib.rs index eb6c6d1ed..56caa5594 100644 --- a/serde_multipart/src/lib.rs +++ b/serde_multipart/src/lib.rs @@ -8,9 +8,9 @@ //! Serde codec for multipart messages. //! -//! Using [`serialize`] / [`deserialize`], fields typed [`Part`] are extracted -//! from the main payload and appended to a list of `parts`. Each part is backed by -//! [`bytes::Bytes`] for cheap, zero-copy sharing. +//! Using [`serialize`] / [`deserialize`], fields typed [`Part`] or [`FragmentedPart`] +//! are extracted from the main payload and appended to lists of parts. Each part is +//! backed by [`bytes::Bytes`] for cheap, zero-copy sharing. //! //! On decode, the body and its parts are reassembled into the original value //! without copying. @@ -20,11 +20,11 @@ //! efficient network I/O without compacting data into a single buffer. //! //! Implementation note: this crate uses Rust's min_specialization feature to enable -//! the use of [`Part`]s with any Serde serializer or deserializer. This feature -//! 
is fairly restrictive, and thus the API offered by [`serialize`] / [`deserialize`] +//! the use of [`Part`]s and [`FragmentedPart`]s with any Serde serializer or deserializer. +//! This feature is fairly restrictive, and thus the API offered by [`serialize`] / [`deserialize`] //! is not customizable. If customization is needed, you need to add specialization -//! implementations for these codecs. See [`part::PartSerializer`] and [`part::PartDeserializer`] -//! for details. +//! implementations for these codecs. See [`part::PartSerializer`], [`part::PartDeserializer`], +//! [`FragmentedPartSerializer`], and [`FragmentedPartDeserializer`] for details. #![feature(min_specialization)] #![feature(assert_matches)] @@ -46,6 +46,7 @@ mod part; mod ser; use bytes::Bytes; use bytes::BytesMut; +pub use part::FragmentedPart; pub use part::Part; use serde::Deserialize; use serde::Serialize; @@ -57,6 +58,7 @@ use serde::Serialize; pub struct Message { body: Part, parts: Vec, + fragmented_parts: Vec, is_illegal: bool, } @@ -66,6 +68,7 @@ impl Message { Self { body, parts, + fragmented_parts: vec![], is_illegal: false, } } @@ -87,60 +90,147 @@ impl Message { /// Returns the total size (in bytes) of the message. pub fn len(&self) -> usize { - self.body.len() + self.parts.iter().map(|part| part.len()).sum::() + self.body.len() + + self.parts.iter().map(|part| part.len()).sum::() + + self + .fragmented_parts + .iter() + .map(|fp| fp.len()) + .sum::() } /// Returns whether the message is empty. It is always false, since the body /// is always defined. pub fn is_empty(&self) -> bool { - self.body.is_empty() && self.parts.iter().all(|part| part.is_empty()) + self.body.is_empty() + && self.parts.iter().all(|part| part.is_empty()) + && self.fragmented_parts.iter().all(|fp| fp.is_empty()) } /// Convert this message into its constituent components. 
- pub fn into_inner(self) -> (Part, Vec) { - (self.body, self.parts) + pub fn into_inner(self) -> (Part, Vec, Vec) { + (self.body, self.parts, self.fragmented_parts) } /// Returns the total size (in bytes) of the message when it is framed. pub fn frame_len(&self) -> usize { - 8 * (1 + self.num_parts()) + self.len() + if self.is_illegal { + // Illegal messages use a simplified frame format: u64::MAX marker + body + return 8 + self.body.len(); + } + + // Headers: body_len (8) + num_regular_parts (8) + num_fragmented (8) + let header_bytes = 3 * 8; + + let body_bytes = self.body.len(); + + let regular_parts_bytes = + self.parts.len() * 8 + self.parts.iter().map(|p| p.len()).sum::(); + + let fragmented_parts_bytes = self.fragmented_parts.len() * 8 + + self + .fragmented_parts + .iter() + .map(|fp| fp.len()) + .sum::(); + + header_bytes + body_bytes + regular_parts_bytes + fragmented_parts_bytes } /// Efficiently frames a message containing the body and all of its parts /// using a simple frame-length encoding: /// /// ```text - /// +--------------------+-------------------+--------------------+-------------------+ ... + - /// | body_len (u64 BE) | body bytes | part1_len (u64 BE) | part1 bytes | | - /// +--------------------+-------------------+--------------------+-------------------+ + - /// repeat - /// for - /// each part + /// ┌─────────────────────────┐ + /// │ body_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ body bytes │ + /// ├─────────────────────────┤ + /// │ num_parts (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part1_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part1 bytes │ + /// ├─────────────────────────┤ + /// │ part2_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ part2 bytes │ + /// ├─────────────────────────┤ + /// │ ... 
│ + /// ├─────────────────────────┤ + /// │ num_fragmented (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag1_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag1 bytes │ + /// ├─────────────────────────┤ + /// │ frag2_len (u64 BE) │ + /// ├─────────────────────────┤ + /// │ frag2 bytes │ + /// ├─────────────────────────┤ + /// │ ... │ + /// └─────────────────────────┘ /// ``` pub fn framed(self) -> Frame { let is_illegal = self.is_illegal; - let (body, parts) = self.into_inner(); + let (body, parts, fragmented_parts) = self.into_inner(); if is_illegal { - assert!(parts.is_empty(), "illegal illegal message"); + assert!( + parts.is_empty() && fragmented_parts.is_empty(), + "illegal illegal message" + ); return Frame::from_buffers(vec![ Bytes::from_owner(u64::MAX.to_be_bytes()), body.into_inner(), ]); } - let mut buffers = Vec::with_capacity(2 + 2 * parts.len()); + let has_fragmented = !fragmented_parts.is_empty(); + + let fragmented_total_parts: usize = + fragmented_parts.iter().map(|fp| fp.as_slice().len()).sum(); + let mut buffers = Vec::with_capacity( + 3 + // body_len + body + num_regular_parts + 2 * parts.len() + // Regular parts (len + data each) + if has_fragmented { 1 + fragmented_parts.len() + fragmented_total_parts } else { 0 }, + ); let body = body.into_inner(); buffers.push(Bytes::from_owner(body.len().to_be_bytes())); buffers.push(body); + // Number of regular parts + buffers.push(Bytes::from_owner(parts.len().to_be_bytes())); + for part in parts { let part = part.into_inner(); + // Length of this part buffers.push(Bytes::from_owner(part.len().to_be_bytes())); + buffers.push(part); } + if has_fragmented { + // Number of FragmentedParts + buffers.push(Bytes::from_owner(fragmented_parts.len().to_be_bytes())); + + for frag_part in fragmented_parts { + let parts = frag_part.into_parts(); + // Length of all parts/fragments + buffers.push(Bytes::from_owner( + (parts.iter().map(|p| p.len()).sum::() as u64).to_be_bytes(), + )); + + for part in 
parts { + buffers.push(part.into_inner()); + } + } + } else { + // Write 0 for num_fragmented if there are none + buffers.push(Bytes::from_owner(0u64.to_be_bytes())); + } + Frame::from_buffers(buffers) } @@ -154,18 +244,44 @@ impl Message { return Ok(Self { body: buf.into(), parts: vec![], + fragmented_parts: vec![], is_illegal: true, }); } let body = buf.split_to(body_len as usize); - let mut parts = Vec::new(); - while !buf.is_empty() { + + // Read number of regular parts + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + let num_regular_parts = buf.get_u64() as usize; + + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + + let mut parts = Vec::with_capacity(num_regular_parts); + for _ in 0..num_regular_parts { parts.push(Self::split_part(&mut buf)?.into()); } + + if buf.len() < 8 { + return Err(std::io::ErrorKind::UnexpectedEof.into()); + } + let num_fragmented = buf.get_u64() as usize; + + let mut fragmented_parts = Vec::with_capacity(num_fragmented); + for _ in 0..num_fragmented { + fragmented_parts.push(FragmentedPart::Contiguous( + Self::split_part(&mut buf)?.into(), + )); + } + Ok(Self { body: body.into(), parts, + fragmented_parts, is_illegal: false, }) } @@ -322,9 +438,11 @@ pub fn serialize_bincode( let mut serializer: part::BincodeSerializer = ser::bincode::Serializer::new(bincode::Serializer::new(buffer_borrow.writer(), options())); value.serialize(&mut serializer)?; + let (parts, fragmented_parts) = serializer.into_parts(); Ok(Message { body: Part(buffer.into_inner().freeze()), - parts: serializer.into_parts(), + parts, + fragmented_parts, is_illegal: false, }) } @@ -336,17 +454,18 @@ where T: serde::de::DeserializeOwned, { if message.is_illegal { - let (body, parts) = message.into_inner(); - if !parts.is_empty() { + let (body, parts, fragmented_parts) = message.into_inner(); + if !parts.is_empty() || !fragmented_parts.is_empty() { return Err(bincode::ErrorKind::Custom("illegal illegal 
message".to_string()).into()); } return bincode::deserialize_from(body.into_inner().reader()); } - let (body, parts) = message.into_inner(); + let (body, parts, fragmented_parts) = message.into_inner(); let mut deserializer = part::BincodeDeserializer::new( bincode::Deserializer::with_reader(body.into_inner().reader(), options()), parts.into(), + fragmented_parts.into(), ); let value = T::deserialize(&mut deserializer)?; // Check that all parts were consumed: @@ -366,6 +485,7 @@ pub fn serialize_illegal_bincode( Ok(Message { body: Part::from(bincode::serialize(value)?), parts: vec![], + fragmented_parts: vec![], is_illegal: true, }) } @@ -507,6 +627,7 @@ mod tests { let message = Message { body: Part::from("hello"), parts: vec![Part::from("world")], + fragmented_parts: vec![], is_illegal: false, }; let err = deserialize_bincode::(message).unwrap_err(); @@ -565,6 +686,7 @@ mod tests { Part::from("xyz"), Part::from("xyzd"), ], + fragmented_parts: vec![], is_illegal: false, }; @@ -573,6 +695,41 @@ mod tests { assert_eq!(Message::from_framed(framed).unwrap(), message); } + #[test] + fn test_fragmented_part_roundtrip() { + let fragments = vec![ + Part::from("Hello"), + Part::from(" "), + Part::from("World"), + Part::from("!"), + ]; + let expected_data = b"Hello World!"; + + let fragmented = FragmentedPart::Fragmented(fragments); + assert!(matches!(fragmented, FragmentedPart::Fragmented(_))); + + #[derive(Serialize, Deserialize, Debug)] + struct TestStruct { + data: FragmentedPart, + } + + let test_struct = TestStruct { data: fragmented }; + + let message = serialize_bincode(&test_struct).unwrap(); + + let mut framed = message.framed(); + let framed_bytes = framed.copy_to_bytes(framed.remaining()); + + let unframed_message = Message::from_framed(framed_bytes).unwrap(); + + let deserialized: TestStruct = deserialize_bincode(unframed_message).unwrap(); + + assert!(matches!(deserialized.data, FragmentedPart::Contiguous(_))); + + let contiguous_bytes = 
deserialized.data.into_bytes(); + assert_eq!(&*contiguous_bytes, expected_data); + } + #[test] fn test_socket_addr() { let socket_addr_v6: SocketAddrV6 = diff --git a/serde_multipart/src/part.rs b/serde_multipart/src/part.rs index a90b3358c..8678e1df9 100644 --- a/serde_multipart/src/part.rs +++ b/serde_multipart/src/part.rs @@ -9,6 +9,7 @@ use std::ops::Deref; use bytes::Bytes; +use bytes::BytesMut; use bytes::buf::Reader as BufReader; use bytes::buf::Writer as BufWriter; use serde::Deserialize; @@ -124,3 +125,164 @@ impl<'de, 'a> PartDeserializer<'de, &'a mut BincodeDeserializer> for Part { deserializer.deserialize_part() } } + +/// A logically contiguous part that may be physically fragmented or contiguous. +/// +/// During serialization, parts are extracted separately (allowing zero-copy from construction). +/// During deserialization, data arrives as a single contiguous `Part`. +/// +/// Use this when: +/// - Construction creates multiple Parts (e.g., multiple pickle writes to a Buffer) +/// - Consumption needs contiguous bytes (e.g., unpickling requires contiguous buffer) +/// - Network read already gives contiguous bytes (no need to split and re-concat) +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FragmentedPart { + /// Multiple fragments that need to be concatenated when accessed + Fragmented(Vec), + /// Already contiguous data (typically from deserialization) + Contiguous(Part), +} + +impl Default for FragmentedPart { + fn default() -> Self { + Self::Contiguous(Part::default()) + } +} + +impl FragmentedPart { + pub fn new(parts: Vec) -> Self { + if parts.len() == 1 { + Self::Contiguous(parts.into_iter().next().unwrap()) + } else { + Self::Fragmented(parts) + } + } + + pub fn into_parts(self) -> Vec { + match self { + Self::Fragmented(parts) => parts, + Self::Contiguous(part) => vec![part], + } + } + + /// Convert into bytes, concatenating fragments if necessary. 
+    pub fn into_bytes(self) -> Bytes {
+        match self {
+            Self::Contiguous(part) => part.into_inner(),
+            // A lone fragment is already contiguous: move it out instead of
+            // copying it into a freshly allocated buffer.
+            Self::Fragmented(mut parts) if parts.len() == 1 => parts.swap_remove(0).into_inner(),
+            Self::Fragmented(parts) => Self::concat_fragments(&parts),
+        }
+    }
+
+    /// Returns the payload as one contiguous [`Bytes`] without consuming `self`.
+    ///
+    /// The result is an owned value, not a borrow: when there is exactly one
+    /// underlying part its bytes are returned via `Part::to_bytes`, otherwise
+    /// the fragments are copied back-to-back into a new buffer.
+    pub fn as_bytes(&self) -> Bytes {
+        match self.as_slice() {
+            [only] => only.to_bytes(),
+            fragments => Self::concat_fragments(fragments),
+        }
+    }
+
+    /// Copies `fragments` back-to-back into a single buffer sized up front to
+    /// their combined length. An empty slice yields an empty `Bytes`.
+    fn concat_fragments(fragments: &[Part]) -> Bytes {
+        let total_len: usize = fragments.iter().map(|p| p.len()).sum();
+        let mut joined = BytesMut::with_capacity(total_len);
+        for fragment in fragments {
+            joined.extend_from_slice(&fragment.to_bytes());
+        }
+        joined.freeze()
+    }
+
+    /// Borrows the underlying parts; a contiguous part is exposed as a
+    /// one-element slice.
+    pub fn as_slice(&self) -> &[Part] {
+        match self {
+            Self::Fragmented(parts) => parts.as_slice(),
+            Self::Contiguous(part) => std::slice::from_ref(part),
+        }
+    }
+
+    /// Returns the total length in bytes of the fragmented part.
+    /// For contiguous parts, this is just the part length.
+    /// For fragmented parts, this is the sum of all fragment lengths.
+    pub fn len(&self) -> usize {
+        self.as_slice().iter().map(|p| p.len()).sum()
+    }
+
+    /// Returns whether the fragmented part is empty.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+}
+
+/// Serialization trait for FragmentedPart (similar to PartSerializer).
+///
+/// It is generic over the serializer so that the bincode multipart serializer
+/// can be given its own specialized implementation, while every other
+/// serializer falls back to the default one.
+trait FragmentedPartSerializer<S: serde::Serializer> {
+    fn serialize(parts: &FragmentedPart, s: S) -> Result<S::Ok, S::Error>;
+}
+
+/// Default: serialize inline as a sequence of parts (same encoding as `Vec<Part>`).
+impl<S: serde::Serializer> FragmentedPartSerializer<S> for FragmentedPart {
+    default fn serialize(part: &FragmentedPart, s: S) -> Result<S::Ok, S::Error> {
+        // `as_slice` exposes a contiguous part as a one-element slice, so both
+        // variants are written as the same sequence without cloning the part
+        // into a temporary `Vec`.
+        part.as_slice().serialize(s)
+    }
+}
+
+/// Specialized for our BincodeSerializer
+impl<'a> FragmentedPartSerializer<&'a mut BincodeSerializer> for FragmentedPart {
+    fn serialize(
+        parts: &FragmentedPart,
+        s: &'a mut BincodeSerializer,
+    ) -> Result<(), bincode::Error> {
+        // Tell the serializer to extract this as a fragmented part
+        s.serialize_fragmented_part(parts);
+        // Serialize as empty Vec in the body (parts are extracted)
+        Vec::<Part>::new().serialize(s)
+    }
+}
+
+impl Serialize for FragmentedPart {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        <Self as FragmentedPartSerializer<S>>::serialize(self, serializer)
+    }
+}
+
+/// Deserialization trait for FragmentedPart
+trait FragmentedPartDeserializer<'de, D: serde::Deserializer<'de>>: Sized {
+    fn deserialize(d: D) -> Result<Self, D::Error>;
+}
+
+/// Default: deserialize as `Vec<Part>`
+impl<'de, D: serde::Deserializer<'de>> FragmentedPartDeserializer<'de, D> for FragmentedPart {
+    default fn deserialize(deserializer: D) -> Result<Self, D::Error> {
+        let parts = Vec::<Part>::deserialize(deserializer)?;
+        // `new` collapses a single part into the `Contiguous` variant.
+        Ok(Self::new(parts))
+    }
+}
+
+/// Specialized for our BincodeDeserializer
+impl<'de, 'a> FragmentedPartDeserializer<'de, &'a mut BincodeDeserializer> for FragmentedPart {
+    fn deserialize(deserializer: &'a mut BincodeDeserializer) -> Result<Self, bincode::Error> {
+        // Read the Vec (should be empty from serialization)
+        let _empty: Vec<Part> = Vec::deserialize(&mut *deserializer)?;
+        // Pull the actual fragmented part from the deserializer
+        deserializer.deserialize_fragmented_part()
+    }
+}
+
+impl<'de> Deserialize<'de> for FragmentedPart { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + >::deserialize(deserializer) + } +} diff --git a/serde_multipart/src/ser/bincode.rs b/serde_multipart/src/ser/bincode.rs index 36b54b753..dbae7e709 100644 --- a/serde_multipart/src/ser/bincode.rs +++ b/serde_multipart/src/ser/bincode.rs @@ -13,13 +13,15 @@ use ::bincode::Options; use serde::Serialize; use serde::ser; +use crate::FragmentedPart; use crate::Part; /// Multipart serializer for bincode. This passes through serialization to bincode, -/// but also records the parts encoded by [`Part::serialize`]. +/// but also records the parts encoded by [`Part::serialize`] and [`FragmentedPart::serialize`]. pub struct Serializer { ser: ::bincode::Serializer, parts: Vec, + fragmented_parts: Vec, } impl Serializer { @@ -27,6 +29,7 @@ impl Serializer { Self { ser, parts: Vec::new(), + fragmented_parts: Vec::new(), } } @@ -35,8 +38,13 @@ impl Serializer { self.parts.push(part.clone()); } - pub(crate) fn into_parts(self) -> Vec { - self.parts + /// Serialize a FragmentedPart by appending it to the fragmented_parts list. + pub(crate) fn serialize_fragmented_part(&mut self, parts: &FragmentedPart) { + self.fragmented_parts.push(parts.clone()); + } + + pub(crate) fn into_parts(self) -> (Vec, Vec) { + (self.parts, self.fragmented_parts) } }