Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions monarch_extension/src/mesh_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use hyperactor_mesh::v1::Name;
use hyperactor_mesh_macros::sel;
use monarch_hyperactor::actor::PythonMessage;
use monarch_hyperactor::actor::PythonMessageKind;
use monarch_hyperactor::buffers::FrozenBuffer;
use monarch_hyperactor::buffers::Buffer;
use monarch_hyperactor::context::PyInstance;
use monarch_hyperactor::instance_dispatch;
use monarch_hyperactor::local_state_broker::LocalStateBrokerActor;
Expand Down Expand Up @@ -571,10 +571,10 @@ impl History {
let exe = remote_exception
.call1((exception.backtrace, traceback, rank))
.unwrap();
let data: FrozenBuffer = pickle.call1((exe,)).unwrap().extract().unwrap();
PythonMessage::new_from_buf(
let mut data: Buffer = pickle.call1((exe,)).unwrap().extract().unwrap();
PythonMessage::new_from_fragmented(
PythonMessageKind::Exception { rank: Some(rank) },
data.inner,
data.into_fragmented_part(),
)
}));

Expand Down
31 changes: 21 additions & 10 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ use pyo3::types::PyType;
use serde::Deserialize;
use serde::Serialize;
use serde_bytes::ByteBuf;
use serde_multipart::FragmentedPart;
use serde_multipart::Part;
use tokio::sync::Mutex;
use tokio::sync::oneshot;
use tracing::Instrument;

use crate::buffers::Buffer;
use crate::buffers::FrozenBuffer;
use crate::config::SHARED_ASYNCIO_RUNTIME;
use crate::context::PyInstance;
Expand Down Expand Up @@ -265,7 +267,7 @@ fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, Py
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
pub struct PythonMessage {
pub kind: PythonMessageKind,
pub message: Part,
pub message: FragmentedPart,
}

struct ResolvedCallMethod {
Expand All @@ -281,7 +283,14 @@ impl PythonMessage {
pub fn new_from_buf(kind: PythonMessageKind, message: impl Into<Part>) -> Self {
Self {
kind,
message: message.into(),
message: FragmentedPart::Contiguous(message.into()),
}
}

pub fn new_from_fragmented(kind: PythonMessageKind, fragmented_part: FragmentedPart) -> Self {
Self {
kind,
message: fragmented_part,
}
}

Expand Down Expand Up @@ -336,7 +345,7 @@ impl PythonMessage {
Ok(ResolvedCallMethod {
method: name,
bytes: FrozenBuffer {
inner: self.message.into_inner(),
inner: self.message.into_bytes(),
},
local_state,
response_port,
Expand Down Expand Up @@ -375,7 +384,7 @@ impl PythonMessage {
Ok(ResolvedCallMethod {
method: name,
bytes: FrozenBuffer {
inner: self.message.into_inner(),
inner: self.message.into_bytes(),
},
local_state,
response_port,
Expand All @@ -394,7 +403,7 @@ impl std::fmt::Debug for PythonMessage {
.field("kind", &self.kind)
.field(
"message",
&hyperactor::data::HexFmt(&(*self.message)[..]).to_string(),
&hyperactor::data::HexFmt(&(*self.message.as_bytes())[..]).to_string(),
)
.finish()
}
Expand Down Expand Up @@ -423,9 +432,11 @@ impl PythonMessage {
#[new]
#[pyo3(signature = (kind, message))]
pub fn new<'py>(kind: PythonMessageKind, message: Bound<'py, PyAny>) -> PyResult<Self> {
if let Ok(buff) = message.extract::<Bound<'py, FrozenBuffer>>() {
let frozen = buff.borrow_mut();
return Ok(PythonMessage::new_from_buf(kind, frozen.inner.clone()));
if let Ok(mut buff) = message.extract::<PyRefMut<'py, Buffer>>() {
return Ok(PythonMessage::new_from_fragmented(
kind,
buff.into_fragmented_part(),
));
} else if let Ok(buff) = message.extract::<Bound<'py, PyBytes>>() {
return Ok(PythonMessage::new_from_buf(
kind,
Expand All @@ -446,7 +457,7 @@ impl PythonMessage {
#[getter]
fn message(&self) -> FrozenBuffer {
FrozenBuffer {
inner: self.message.clone().into_inner(),
inner: self.message.as_bytes(),
}
}
}
Expand Down Expand Up @@ -1001,7 +1012,7 @@ mod tests {
},
response_port: Some(EitherPortRef::Unbounded(port_ref.clone().into())),
},
message: Part::from(vec![1, 2, 3]),
message: FragmentedPart::Contiguous(Part::from(vec![1, 2, 3])),
};
{
let mut erased = ErasedUnbound::try_from_message(message.clone()).unwrap();
Expand Down
127 changes: 78 additions & 49 deletions monarch_hyperactor/src/buffers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,86 +12,92 @@ use std::ffi::c_int;
use std::ffi::c_void;

use bytes::Buf;
use bytes::BytesMut;
use hyperactor::Named;
use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use pyo3::types::PyBytesMethods;
use serde::Deserialize;
use serde::Serialize;
use serde_multipart::FragmentedPart;
use serde_multipart::Part;

/// Wrapper that keeps Py<PyBytes> alive while allowing zero-copy access to its memory
struct PyBytesWrapper {
_py_bytes: Py<PyBytes>,
ptr: *const u8,
len: usize,
}

impl PyBytesWrapper {
fn new(py_bytes: Py<PyBytes>) -> Self {
let (ptr, len) = Python::with_gil(|py| {
let bytes_ref = py_bytes.as_bytes(py);
(bytes_ref.as_ptr(), bytes_ref.len())
});
Self {
_py_bytes: py_bytes,
ptr,
len,
}
}
}

impl AsRef<[u8]> for PyBytesWrapper {
fn as_ref(&self) -> &[u8] {
// SAFETY: ptr is valid as long as py_bytes is alive (kept alive by Py<PyBytes>)
// Python won't free the memory until the Py<PyBytes> refcount reaches 0
unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
}
}

// SAFETY: Py<PyBytes> is Send/Sync for immutable bytes
unsafe impl Send for PyBytesWrapper {}
// SAFETY: Py<PyBytes> is Send/Sync for immutable bytes
unsafe impl Sync for PyBytesWrapper {}

/// A mutable buffer for reading and writing bytes data.
///
/// The `Buffer` struct provides an interface for accumulating byte data that can be written to
/// and then frozen into an immutable `FrozenBuffer` for reading. It uses the `bytes::BytesMut`
/// internally for efficient memory management.
/// The `Buffer` struct provides an interface for accumulating byte data from Python `bytes` objects
/// that can be converted into a `FragmentedPart` for zero-copy multipart message serialization.
/// It accumulates references to Python bytes objects without copying.
///
/// # Examples
///
/// ```python
/// from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer
///
/// # Create a new buffer with default capacity (4096 bytes)
/// # Create a new buffer
/// buffer = Buffer()
///
/// # Write some data
/// data = b"Hello, World!"
/// bytes_written = buffer.write(data)
///
/// # Check length
/// print(len(buffer)) # 13
///
/// # Freeze for reading
/// frozen = buffer.freeze()
/// content = frozen.read()
/// # Use in multipart serialization
/// # The buffer accumulates multiple writes as separate fragments
/// ```
#[pyclass(subclass, module = "monarch._rust_bindings.monarch_hyperactor.buffers")]
#[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)]
#[derive(Clone, Default)]
pub struct Buffer {
pub(crate) inner: bytes::BytesMut,
}

impl Buffer {
/// Consumes the Buffer and returns the underlying BytesMut.
/// This allows zero-copy access to the raw buffer data.
pub fn into_inner(self) -> bytes::BytesMut {
self.inner
}
}

impl<T> From<T> for Buffer
where
T: Into<BytesMut>,
{
fn from(value: T) -> Self {
Self {
inner: value.into(),
}
}
inner: Vec<Py<PyBytes>>,
}

#[pymethods]
impl Buffer {
/// Creates a new empty buffer with specified initial capacity.
///
/// # Arguments
/// * `size` - Initial capacity in bytes (default: 4096)
///
/// # Returns
/// A new empty `Buffer` instance with the specified capacity.
#[new]
#[pyo3(signature=(size=4096))]
fn new(size: usize) -> Self {
Self {
inner: bytes::BytesMut::with_capacity(size),
}
fn new() -> Self {
Self { inner: Vec::new() }
}

/// Writes bytes data to the buffer.
///
/// Appends the provided bytes to the end of the buffer, extending its capacity
/// if necessary.
/// This keeps a reference to the Python bytes object without copying.
///
/// # Arguments
/// * `buff` - The bytes object to write to the buffer
Expand All @@ -100,26 +106,49 @@ impl Buffer {
/// The number of bytes written (always equal to the length of input bytes)
fn write<'py>(&mut self, buff: &Bound<'py, PyBytes>) -> usize {
let bytes_written = buff.as_bytes().len();
self.inner.extend_from_slice(buff.as_bytes());
self.inner.push(buff.clone().unbind());
bytes_written
}

/// Freezes this buffer into an immutable `FrozenBuffer`.
/// Freezes the buffer, converting it into an immutable `FrozenBuffer` for reading.
///
/// This consumes all accumulated PyBytes and converts them into a contiguous bytes buffer.
/// After freezing, the original buffer is cleared.
///
/// This operation consumes the mutable buffer's contents, transferring ownership
/// to a new `FrozenBuffer` that can only be read from. The original buffer
/// becomes empty after this operation.
/// This operation should avoided in hot paths as it creates a copy in order to concatenate
/// bytes that are fragmented in memory into a single series of contiguous bytes
///
/// # Returns
/// A new `FrozenBuffer` containing all the data that was in this buffer
/// A new `FrozenBuffer` containing all the bytes that were written to this buffer
fn freeze(&mut self) -> FrozenBuffer {
let buff = std::mem::take(&mut self.inner);
let fragmented_part = self.into_fragmented_part();
FrozenBuffer {
inner: buff.freeze(),
inner: fragmented_part.into_bytes(),
}
}
}

impl Buffer {
/// Converts accumulated `PyBytes` objects to [`FragmentedPart`] for zero-copy multipart messages.
///
/// Returns a `FragmentedPart::Fragmented` variant since the buffer accumulates multiple
/// separate PyBytes objects that remain physically fragmented.
pub fn into_fragmented_part(&mut self) -> FragmentedPart {
let inner = std::mem::take(&mut self.inner);

FragmentedPart::Fragmented(
inner
.into_iter()
.map(|py_bytes| {
let wrapper = PyBytesWrapper::new(py_bytes);
let bytes = bytes::Bytes::from_owner(wrapper);
Part::from(bytes)
})
.collect(),
)
}
}

/// An immutable buffer for reading bytes data.
///
/// The `FrozenBuffer` struct provides a read-only interface to byte data. Once created,
Expand Down
8 changes: 4 additions & 4 deletions monarch_tensor_worker/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ use hyperactor::mailbox::PortReceiver;
use hyperactor::proc::Proc;
use monarch_hyperactor::actor::PythonMessage;
use monarch_hyperactor::actor::PythonMessageKind;
use monarch_hyperactor::buffers::FrozenBuffer;
use monarch_hyperactor::buffers::Buffer;
use monarch_hyperactor::local_state_broker::BrokerId;
use monarch_hyperactor::local_state_broker::LocalState;
use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage;
Expand Down Expand Up @@ -103,16 +103,16 @@ fn pickle_python_result(
.unwrap()
.getattr("_pickle")
.unwrap();
let data: FrozenBuffer = pickle
let mut data: Buffer = pickle
.call1((result,))
.map_err(|pyerr| anyhow::Error::from(SerializablePyErr::from(py, &pyerr)))?
.extract()
.unwrap();
Ok(PythonMessage::new_from_buf(
Ok(PythonMessage::new_from_fragmented(
PythonMessageKind::Result {
rank: Some(worker_rank),
},
data.inner,
data.into_fragmented_part(),
))
}

Expand Down
4 changes: 2 additions & 2 deletions python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ from typing import (
Union,
)

from monarch._rust_bindings.monarch_hyperactor.buffers import FrozenBuffer
from monarch._rust_bindings.monarch_hyperactor.buffers import Buffer, FrozenBuffer

from monarch._rust_bindings.monarch_hyperactor.mailbox import (
Mailbox,
Expand Down Expand Up @@ -204,7 +204,7 @@ class PythonMessage:
def __init__(
self,
kind: PythonMessageKind,
message: Union[FrozenBuffer, bytes],
message: Union[Buffer, bytes],
) -> None: ...
@property
def message(self) -> FrozenBuffer:
Expand Down
Loading