Skip to content

Commit 39a4051

Browse files
thomasywang and facebook-github-bot
authored and committed
Use PyBytes as backing for Buffer (#1817)
Summary: Currently our pickling is still not truly zero-copy because the Pickler calls `Buffer::write()`, which copies bytes from `PyBytes` into a `BytesMut` via `extend_from_slice()`. To avoid the copy, we make `Buffer` backed by a `Vec<Py<PyBytes>>`, with each call to `Buffer::write()` pushing the `PyBytes` onto the Vec. The following figures show a round trip produced from `await am.echo.call(b"x" * 1024 * 1024)`. Before: Send path: 600us pickle (purple), write frames (dark green), receive frames (light green), 130us unpickle; Reply path: 600us pickle, write frames, receive frames, 130us unpickle {F1983399784} After: Send path: 20us pickle (purple), write frames (dark green), receive frames (light green), 200us unpickle; Reply path: 20us pickle, write frames, receive frames, 150us unpickle {F1983399596} Differential Revision: D86696391
1 parent a8611d4 commit 39a4051

File tree

12 files changed

+135
-110
lines changed

12 files changed

+135
-110
lines changed

monarch_extension/src/mesh_controller.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use hyperactor_mesh::v1::Name;
4444
use hyperactor_mesh_macros::sel;
4545
use monarch_hyperactor::actor::PythonMessage;
4646
use monarch_hyperactor::actor::PythonMessageKind;
47-
use monarch_hyperactor::buffers::FrozenBuffer;
47+
use monarch_hyperactor::buffers::Buffer;
4848
use monarch_hyperactor::context::PyInstance;
4949
use monarch_hyperactor::instance_dispatch;
5050
use monarch_hyperactor::local_state_broker::LocalStateBrokerActor;
@@ -571,10 +571,10 @@ impl History {
571571
let exe = remote_exception
572572
.call1((exception.backtrace, traceback, rank))
573573
.unwrap();
574-
let data: FrozenBuffer = pickle.call1((exe,)).unwrap().extract().unwrap();
575-
PythonMessage::new_from_buf(
574+
let mut data: Buffer = pickle.call1((exe,)).unwrap().extract().unwrap();
575+
PythonMessage::new_from_fragmented(
576576
PythonMessageKind::Exception { rank: Some(rank) },
577-
data.inner,
577+
data.into_fragmented_part(),
578578
)
579579
}));
580580

monarch_hyperactor/src/actor.rs

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,13 @@ use pyo3::types::PyType;
4747
use serde::Deserialize;
4848
use serde::Serialize;
4949
use serde_bytes::ByteBuf;
50+
use serde_multipart::FragmentedPart;
5051
use serde_multipart::Part;
5152
use tokio::sync::Mutex;
5253
use tokio::sync::oneshot;
5354
use tracing::Instrument;
5455

56+
use crate::buffers::Buffer;
5557
use crate::buffers::FrozenBuffer;
5658
use crate::config::SHARED_ASYNCIO_RUNTIME;
5759
use crate::context::PyInstance;
@@ -265,7 +267,7 @@ fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, Py
265267
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
266268
pub struct PythonMessage {
267269
pub kind: PythonMessageKind,
268-
pub message: Part,
270+
pub message: FragmentedPart,
269271
}
270272

271273
struct ResolvedCallMethod {
@@ -281,7 +283,14 @@ impl PythonMessage {
281283
pub fn new_from_buf(kind: PythonMessageKind, message: impl Into<Part>) -> Self {
282284
Self {
283285
kind,
284-
message: message.into(),
286+
message: FragmentedPart::Contiguous(message.into()),
287+
}
288+
}
289+
290+
pub fn new_from_fragmented(kind: PythonMessageKind, fragmented_part: FragmentedPart) -> Self {
291+
Self {
292+
kind,
293+
message: fragmented_part,
285294
}
286295
}
287296

@@ -336,7 +345,7 @@ impl PythonMessage {
336345
Ok(ResolvedCallMethod {
337346
method: name,
338347
bytes: FrozenBuffer {
339-
inner: self.message.into_inner(),
348+
inner: self.message.into_bytes(),
340349
},
341350
local_state,
342351
response_port,
@@ -375,7 +384,7 @@ impl PythonMessage {
375384
Ok(ResolvedCallMethod {
376385
method: name,
377386
bytes: FrozenBuffer {
378-
inner: self.message.into_inner(),
387+
inner: self.message.into_bytes(),
379388
},
380389
local_state,
381390
response_port,
@@ -394,7 +403,7 @@ impl std::fmt::Debug for PythonMessage {
394403
.field("kind", &self.kind)
395404
.field(
396405
"message",
397-
&hyperactor::data::HexFmt(&(*self.message)[..]).to_string(),
406+
&hyperactor::data::HexFmt(&(*self.message.as_bytes())[..]).to_string(),
398407
)
399408
.finish()
400409
}
@@ -423,9 +432,11 @@ impl PythonMessage {
423432
#[new]
424433
#[pyo3(signature = (kind, message))]
425434
pub fn new<'py>(kind: PythonMessageKind, message: Bound<'py, PyAny>) -> PyResult<Self> {
426-
if let Ok(buff) = message.extract::<Bound<'py, FrozenBuffer>>() {
427-
let frozen = buff.borrow_mut();
428-
return Ok(PythonMessage::new_from_buf(kind, frozen.inner.clone()));
435+
if let Ok(mut buff) = message.extract::<PyRefMut<'py, Buffer>>() {
436+
return Ok(PythonMessage::new_from_fragmented(
437+
kind,
438+
buff.into_fragmented_part(),
439+
));
429440
} else if let Ok(buff) = message.extract::<Bound<'py, PyBytes>>() {
430441
return Ok(PythonMessage::new_from_buf(
431442
kind,
@@ -446,7 +457,7 @@ impl PythonMessage {
446457
#[getter]
447458
fn message(&self) -> FrozenBuffer {
448459
FrozenBuffer {
449-
inner: self.message.clone().into_inner(),
460+
inner: self.message.as_bytes(),
450461
}
451462
}
452463
}
@@ -1001,7 +1012,7 @@ mod tests {
10011012
},
10021013
response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())),
10031014
},
1004-
message: Part::from(vec![1, 2, 3]),
1015+
message: FragmentedPart::Contiguous(Part::from(vec![1, 2, 3])),
10051016
};
10061017
{
10071018
let mut erased = ErasedUnbound::try_from_message(message.clone()).unwrap();

monarch_hyperactor/src/buffers.rs

Lines changed: 78 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,86 +12,92 @@ use std::ffi::c_int;
1212
use std::ffi::c_void;
1313

1414
use bytes::Buf;
15-
use bytes::BytesMut;
1615
use hyperactor::Named;
1716
use pyo3::buffer::PyBuffer;
1817
use pyo3::prelude::*;
1918
use pyo3::types::PyBytes;
2019
use pyo3::types::PyBytesMethods;
2120
use serde::Deserialize;
2221
use serde::Serialize;
22+
use serde_multipart::FragmentedPart;
23+
use serde_multipart::Part;
24+
25+
/// Wrapper that keeps Py<PyBytes> alive while allowing zero-copy access to its memory
26+
struct PyBytesWrapper {
27+
_py_bytes: Py<PyBytes>,
28+
ptr: *const u8,
29+
len: usize,
30+
}
31+
32+
impl PyBytesWrapper {
33+
fn new(py_bytes: Py<PyBytes>) -> Self {
34+
let (ptr, len) = Python::with_gil(|py| {
35+
let bytes_ref = py_bytes.as_bytes(py);
36+
(bytes_ref.as_ptr(), bytes_ref.len())
37+
});
38+
Self {
39+
_py_bytes: py_bytes,
40+
ptr,
41+
len,
42+
}
43+
}
44+
}
45+
46+
impl AsRef<[u8]> for PyBytesWrapper {
47+
fn as_ref(&self) -> &[u8] {
48+
// SAFETY: ptr is valid as long as py_bytes is alive (kept alive by Py<PyBytes>)
49+
// Python won't free the memory until the Py<PyBytes> refcount reaches 0
50+
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
51+
}
52+
}
53+
54+
// SAFETY: Py<PyBytes> is Send/Sync for immutable bytes
55+
unsafe impl Send for PyBytesWrapper {}
56+
// SAFETY: Py<PyBytes> is Send/Sync for immutable bytes
57+
unsafe impl Sync for PyBytesWrapper {}
2358

2459
/// A mutable buffer for reading and writing bytes data.
2560
///
26-
/// The `Buffer` struct provides an interface for accumulating byte data that can be written to
27-
/// and then frozen into an immutable `FrozenBuffer` for reading. It uses the `bytes::BytesMut`
28-
/// internally for efficient memory management.
61+
/// The `Buffer` struct provides an interface for accumulating byte data from Python `bytes` objects
62+
/// that can be converted into a `FragmentedPart` for zero-copy multipart message serialization.
63+
/// It accumulates references to Python bytes objects without copying.
2964
///
3065
/// # Examples
3166
///
3267
/// ```python
3368
/// from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer
3469
///
35-
/// # Create a new buffer with default capacity (4096 bytes)
70+
/// # Create a new buffer
3671
/// buffer = Buffer()
3772
///
3873
/// # Write some data
3974
/// data = b"Hello, World!"
4075
/// bytes_written = buffer.write(data)
4176
///
42-
/// # Check length
43-
/// print(len(buffer)) # 13
44-
///
45-
/// # Freeze for reading
46-
/// frozen = buffer.freeze()
47-
/// content = frozen.read()
77+
/// # Use in multipart serialization
78+
/// # The buffer accumulates multiple writes as separate fragments
4879
/// ```
4980
#[pyclass(subclass, module = "monarch._rust_bindings.monarch_hyperactor.buffers")]
50-
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
81+
#[derive(Clone, Default)]
5182
pub struct Buffer {
52-
pub(crate) inner: bytes::BytesMut,
53-
}
54-
55-
impl Buffer {
56-
/// Consumes the Buffer and returns the underlying BytesMut.
57-
/// This allows zero-copy access to the raw buffer data.
58-
pub fn into_inner(self) -> bytes::BytesMut {
59-
self.inner
60-
}
61-
}
62-
63-
impl<T> From<T> for Buffer
64-
where
65-
T: Into<BytesMut>,
66-
{
67-
fn from(value: T) -> Self {
68-
Self {
69-
inner: value.into(),
70-
}
71-
}
83+
inner: Vec<Py<PyBytes>>,
7284
}
7385

7486
#[pymethods]
7587
impl Buffer {
7688
/// Creates a new empty buffer.
7789
///
78-
/// # Arguments
79-
/// * `size` - Initial capacity in bytes (default: 4096)
8090
///
8191
/// # Returns
8292
/// A new empty `Buffer` instance.
8393
#[new]
84-
#[pyo3(signature=(size=4096))]
85-
fn new(size: usize) -> Self {
86-
Self {
87-
inner: bytes::BytesMut::with_capacity(size),
88-
}
94+
fn new() -> Self {
95+
Self { inner: Vec::new() }
8996
}
9097

9198
/// Writes bytes data to the buffer.
9299
///
93-
/// Appends the provided bytes to the end of the buffer, extending its capacity
94-
/// if necessary.
100+
/// This keeps a reference to the Python bytes object without copying.
95101
///
96102
/// # Arguments
97103
/// * `buff` - The bytes object to write to the buffer
@@ -100,26 +106,49 @@ impl Buffer {
100106
/// The number of bytes written (always equal to the length of input bytes)
101107
fn write<'py>(&mut self, buff: &Bound<'py, PyBytes>) -> usize {
102108
let bytes_written = buff.as_bytes().len();
103-
self.inner.extend_from_slice(buff.as_bytes());
109+
self.inner.push(buff.clone().unbind());
104110
bytes_written
105111
}
106112

107-
/// Freezes this buffer into an immutable `FrozenBuffer`.
113+
/// Freezes the buffer, converting it into an immutable `FrozenBuffer` for reading.
114+
///
115+
/// This consumes all accumulated PyBytes and converts them into a contiguous bytes buffer.
116+
/// After freezing, the original buffer is cleared.
108117
///
109-
/// This operation consumes the mutable buffer's contents, transferring ownership
110-
/// to a new `FrozenBuffer` that can only be read from. The original buffer
111-
/// becomes empty after this operation.
118+
/// This operation should be avoided in hot paths as it creates a copy in order to concatenate
119+
/// bytes that are fragmented in memory into a single series of contiguous bytes
112120
///
113121
/// # Returns
114-
/// A new `FrozenBuffer` containing all the data that was in this buffer
122+
/// A new `FrozenBuffer` containing all the bytes that were written to this buffer
115123
fn freeze(&mut self) -> FrozenBuffer {
116-
let buff = std::mem::take(&mut self.inner);
124+
let fragmented_part = self.into_fragmented_part();
117125
FrozenBuffer {
118-
inner: buff.freeze(),
126+
inner: fragmented_part.into_bytes(),
119127
}
120128
}
121129
}
122130

131+
impl Buffer {
132+
/// Converts accumulated `PyBytes` objects to [`FragmentedPart`] for zero-copy multipart messages.
133+
///
134+
/// Returns a `FragmentedPart::Fragmented` variant since the buffer accumulates multiple
135+
/// separate PyBytes objects that remain physically fragmented.
136+
pub fn into_fragmented_part(&mut self) -> FragmentedPart {
137+
let inner = std::mem::take(&mut self.inner);
138+
139+
FragmentedPart::Fragmented(
140+
inner
141+
.into_iter()
142+
.map(|py_bytes| {
143+
let wrapper = PyBytesWrapper::new(py_bytes);
144+
let bytes = bytes::Bytes::from_owner(wrapper);
145+
Part::from(bytes)
146+
})
147+
.collect(),
148+
)
149+
}
150+
}
151+
123152
/// An immutable buffer for reading bytes data.
124153
///
125154
/// The `FrozenBuffer` struct provides a read-only interface to byte data. Once created,

monarch_tensor_worker/src/stream.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use hyperactor::mailbox::PortReceiver;
3737
use hyperactor::proc::Proc;
3838
use monarch_hyperactor::actor::PythonMessage;
3939
use monarch_hyperactor::actor::PythonMessageKind;
40-
use monarch_hyperactor::buffers::FrozenBuffer;
40+
use monarch_hyperactor::buffers::Buffer;
4141
use monarch_hyperactor::local_state_broker::BrokerId;
4242
use monarch_hyperactor::local_state_broker::LocalState;
4343
use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
@@ -103,16 +103,16 @@ fn pickle_python_result(
103103
.unwrap()
104104
.getattr("_pickle")
105105
.unwrap();
106-
let data: FrozenBuffer = pickle
106+
let mut data: Buffer = pickle
107107
.call1((result,))
108108
.map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
109109
.extract()
110110
.unwrap();
111-
Ok(PythonMessage::new_from_buf(
111+
Ok(PythonMessage::new_from_fragmented(
112112
PythonMessageKind::Result {
113113
rank: Some(worker_rank),
114114
},
115-
data.inner,
115+
data.into_fragmented_part(),
116116
))
117117
}
118118

python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ from typing import (
2323
Union,
2424
)
2525

26-
from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer
26+
from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer, FrozenBuffer
2727

2828
from monarch._rust_bindings.monarch_hyperactor.mailbox import (
2929
Mailbox,
@@ -204,7 +204,7 @@ class PythonMessage:
204204
def __init__(
205205
self,
206206
kind: PythonMessageKind,
207-
message: Union[FrozenBuffer, bytes],
207+
message: Union[Buffer, bytes],
208208
) -> None: ...
209209
@property
210210
def message(self) -> FrozenBuffer:

0 commit comments

Comments
 (0)