Skip to content

Commit d1fccaf

Browse files
authored
Optimize packing of bytearray (#51)
By special-casing `bytearray`, we can avoid an allocation and a complete extra copy of the data when packing it. This speeds up packing of `bytearray`s by roughly 1/3.
1 parent c4d77ed commit d1fccaf

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

changelog.d/51.improve.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Optimize packing of `bytearray`<ISSUES_LIST>.
2+
By special-casing `bytearray`, we can avoid an allocation and a complete extra copy of the data when packing it.
3+
This speeds up packing of `bytearray`s by roughly 1/3.

src/codec/packstream/v1/pack.rs

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ use std::sync::OnceLock;
1818

1919
use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError};
2020
use pyo3::prelude::*;
21+
use pyo3::sync::with_critical_section;
2122
use pyo3::sync::OnceLockExt;
22-
use pyo3::types::{PyBytes, PyDict, PyString, PyType};
23+
use pyo3::types::{PyByteArray, PyBytes, PyDict, PyString, PyTuple, PyType};
2324
use pyo3::{intern, IntoPyObjectExt};
2425

2526
use super::super::Structure;
@@ -43,6 +44,35 @@ struct TypeMappings {
4344

4445
impl TypeMappings {
4546
fn new(locals: &Bound<PyDict>) -> PyResult<Self> {
47+
/// Build a tuple of the given types with `bytes`, `bytearray`, and their subclasses removed.
48+
/// Those types are dropped because they are handled specially in `pack`; items that are not type objects are kept.
49+
/// If iterating the input or building the tuple fails for any reason, the original input is returned unchanged.
50+
fn filter_bytes_types(types: Bound<PyAny>) -> Bound<PyAny> {
51+
fn inner<'py>(types: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
52+
fn is_of_known_bytes_types(typ: &Bound<PyType>) -> PyResult<bool> {
53+
Ok(typ.is_subclass_of::<PyBytes>()? || typ.is_subclass_of::<PyByteArray>()?)
54+
}
55+
56+
let py = types.py();
57+
let types = types
58+
.try_iter()?
59+
.filter(|typ| {
60+
let Ok(typ) = typ else {
61+
return true; // keep iteration errors so the `collect` below surfaces them
62+
};
63+
let Ok(typ) = typ.downcast::<PyType>() else {
64+
return true; // not a type object: keep it untouched
65+
};
66+
is_of_known_bytes_types(typ).map(|b| !b).unwrap_or(true) // keep the type if the subclass check itself fails
67+
})
68+
.collect::<Result<Vec<_>, _>>()?;
69+
70+
Ok(PyTuple::new(py, types)?.into_any())
71+
}
72+
73+
inner(&types).unwrap_or(types) // best effort: fall back to the unfiltered input on any error
74+
}
75+
4676
let py = locals.py();
4777
Ok(Self {
4878
none_values: locals
@@ -87,12 +117,15 @@ impl TypeMappings {
87117
PyErr::new::<PyValueError, _>("Type mappings are missing MAPPING_TYPES.")
88118
})?
89119
.into_py_any(py)?,
90-
bytes_types: locals
91-
.get_item("BYTES_TYPES")?
92-
.ok_or_else(|| {
93-
PyErr::new::<PyValueError, _>("Type mappings are missing BYTES_TYPES.")
94-
})?
95-
.into_py_any(py)?,
120+
bytes_types: filter_bytes_types(
121+
locals
122+
.get_item("BYTES_TYPES")?
123+
.ok_or_else(|| {
124+
PyErr::new::<PyValueError, _>("Type mappings are missing BYTES_TYPES.")
125+
})?
126+
.into_bound_py_any(py)?,
127+
)
128+
.unbind(),
96129
})
97130
}
98131
}
@@ -170,8 +203,18 @@ impl<'a> PackStreamEncoder<'a> {
170203
return self.write_string(value.extract::<&str>()?);
171204
}
172205

173-
if value.is_instance(self.type_mappings.bytes_types.bind(py))? {
174-
return self.write_bytes(value.extract::<Cow<[u8]>>()?);
206+
if let Ok(value) = value.downcast::<PyBytes>() {
207+
return self.write_bytes(value.as_bytes());
208+
} else if let Ok(value) = value.downcast::<PyByteArray>() {
209+
return with_critical_section(value, || {
210+
// SAFETY:
211+
// * we're holding the GIL/are attached to the Python interpreter
212+
// * we're using a critical section to ensure exclusive access to the byte array
213+
// * we don't interact with the interpreter/PyO3 APIs while reading the bytes
214+
unsafe { self.write_bytes(value.as_bytes()) }
215+
});
216+
} else if value.is_instance(self.type_mappings.bytes_types.bind(py))? {
217+
return self.write_bytes(&value.extract::<Cow<[u8]>>()?);
175218
}
176219

177220
if value.is_instance(self.type_mappings.sequence_types.bind(py))? {
@@ -268,7 +311,7 @@ impl<'a> PackStreamEncoder<'a> {
268311
Ok(())
269312
}
270313

271-
fn write_bytes(&mut self, b: Cow<[u8]>) -> PyResult<()> {
314+
fn write_bytes(&mut self, b: &[u8]) -> PyResult<()> {
272315
let size = Self::usize_to_u64(b.len())?;
273316
if size <= 255 {
274317
self.buffer.extend(&[BYTES_8]);

0 commit comments

Comments
 (0)