@@ -18,8 +18,9 @@ use std::sync::OnceLock;
1818
1919use pyo3:: exceptions:: { PyOverflowError , PyTypeError , PyValueError } ;
2020use pyo3:: prelude:: * ;
21+ use pyo3:: sync:: with_critical_section;
2122use pyo3:: sync:: OnceLockExt ;
22- use pyo3:: types:: { PyBytes , PyDict , PyString , PyType } ;
23+ use pyo3:: types:: { PyByteArray , PyBytes , PyDict , PyString , PyTuple , PyType } ;
2324use pyo3:: { intern, IntoPyObjectExt } ;
2425
2526use super :: super :: Structure ;
@@ -43,6 +44,35 @@ struct TypeMappings {
4344
4445impl TypeMappings {
4546 fn new ( locals : & Bound < PyDict > ) -> PyResult < Self > {
47+ /// Remove some byte types from an iterable of types.
48+ /// Types removed are `bytes`, `bytearray`, as those are handled specially in `pack`.
49+ /// If the filtering fails for any reason, it returns the original input.
50+ fn filter_bytes_types ( types : Bound < PyAny > ) -> Bound < PyAny > {
51+ fn inner < ' py > ( types : & Bound < ' py , PyAny > ) -> PyResult < Bound < ' py , PyAny > > {
52+ fn is_of_known_bytes_types ( typ : & Bound < PyType > ) -> PyResult < bool > {
53+ Ok ( typ. is_subclass_of :: < PyBytes > ( ) ? || typ. is_subclass_of :: < PyByteArray > ( ) ?)
54+ }
55+
56+ let py = types. py ( ) ;
57+ let types = types
58+ . try_iter ( ) ?
59+ . filter ( |typ| {
60+ let Ok ( typ) = typ else {
61+ return true ;
62+ } ;
63+ let Ok ( typ) = typ. downcast :: < PyType > ( ) else {
64+ return true ;
65+ } ;
66+ is_of_known_bytes_types ( typ) . map ( |b| !b) . unwrap_or ( true )
67+ } )
68+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
69+
70+ Ok ( PyTuple :: new ( py, types) ?. into_any ( ) )
71+ }
72+
73+ inner ( & types) . unwrap_or ( types)
74+ }
75+
4676 let py = locals. py ( ) ;
4777 Ok ( Self {
4878 none_values : locals
@@ -87,12 +117,15 @@ impl TypeMappings {
87117 PyErr :: new :: < PyValueError , _ > ( "Type mappings are missing MAPPING_TYPES." )
88118 } ) ?
89119 . into_py_any ( py) ?,
90- bytes_types : locals
91- . get_item ( "BYTES_TYPES" ) ?
92- . ok_or_else ( || {
93- PyErr :: new :: < PyValueError , _ > ( "Type mappings are missing BYTES_TYPES." )
94- } ) ?
95- . into_py_any ( py) ?,
120+ bytes_types : filter_bytes_types (
121+ locals
122+ . get_item ( "BYTES_TYPES" ) ?
123+ . ok_or_else ( || {
124+ PyErr :: new :: < PyValueError , _ > ( "Type mappings are missing BYTES_TYPES." )
125+ } ) ?
126+ . into_bound_py_any ( py) ?,
127+ )
128+ . unbind ( ) ,
96129 } )
97130 }
98131}
@@ -170,8 +203,18 @@ impl<'a> PackStreamEncoder<'a> {
170203 return self . write_string ( value. extract :: < & str > ( ) ?) ;
171204 }
172205
173- if value. is_instance ( self . type_mappings . bytes_types . bind ( py) ) ? {
174- return self . write_bytes ( value. extract :: < Cow < [ u8 ] > > ( ) ?) ;
206+ if let Ok ( value) = value. downcast :: < PyBytes > ( ) {
207+ return self . write_bytes ( value. as_bytes ( ) ) ;
208+ } else if let Ok ( value) = value. downcast :: < PyByteArray > ( ) {
209+ return with_critical_section ( value, || {
210+ // SAFETY:
211+ // * we're holding the GIL/are attached to the Python interpreter
212+ // * we're using a critical section to ensure exclusive access to the byte array
213+ // * we don't interact with the interpreter/PyO3 APIs while reading the bytes
214+ unsafe { self . write_bytes ( value. as_bytes ( ) ) }
215+ } ) ;
216+ } else if value. is_instance ( self . type_mappings . bytes_types . bind ( py) ) ? {
217+ return self . write_bytes ( & value. extract :: < Cow < [ u8 ] > > ( ) ?) ;
175218 }
176219
177220 if value. is_instance ( self . type_mappings . sequence_types . bind ( py) ) ? {
@@ -268,7 +311,7 @@ impl<'a> PackStreamEncoder<'a> {
268311 Ok ( ( ) )
269312 }
270313
271- fn write_bytes ( & mut self , b : Cow < [ u8 ] > ) -> PyResult < ( ) > {
314+ fn write_bytes ( & mut self , b : & [ u8 ] ) -> PyResult < ( ) > {
272315 let size = Self :: usize_to_u64 ( b. len ( ) ) ?;
273316 if size <= 255 {
274317 self . buffer . extend ( & [ BYTES_8 ] ) ;
0 commit comments