Skip to content

Commit c4d77ed

Browse files
authored
Simplify packer module init (#52)
Using `OnceLockExt` instead of `GILOnceCell` for simpler code and fewer caveats. See also: https://pyo3.rs/v0.25.1/free-threading.html?highlight=free#thread-safe-single-initialization
1 parent 2a0152f commit c4d77ed

File tree

1 file changed

+13
-23
lines changed

1 file changed

+13
-23
lines changed

src/codec/packstream/v1/pack.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
// limitations under the License.
1515

1616
use std::borrow::Cow;
17-
use std::sync::atomic::{AtomicBool, Ordering};
17+
use std::sync::OnceLock;
1818

19-
use pyo3::exceptions::{PyImportError, PyOverflowError, PyTypeError, PyValueError};
19+
use pyo3::exceptions::{PyOverflowError, PyTypeError, PyValueError};
2020
use pyo3::prelude::*;
21-
use pyo3::sync::GILOnceCell;
21+
use pyo3::sync::OnceLockExt;
2222
use pyo3::types::{PyBytes, PyDict, PyString, PyType};
2323
use pyo3::{intern, IntoPyObjectExt};
2424

@@ -97,29 +97,19 @@ impl TypeMappings {
9797
}
9898
}
9999

100-
static TYPE_MAPPINGS: GILOnceCell<PyResult<TypeMappings>> = GILOnceCell::new();
101-
static TYPE_MAPPINGS_INIT: AtomicBool = AtomicBool::new(false);
100+
static TYPE_MAPPINGS: OnceLock<PyResult<TypeMappings>> = OnceLock::new();
102101

103102
fn get_type_mappings(py: Python<'_>) -> PyResult<&'static TypeMappings> {
104-
let mappings = TYPE_MAPPINGS.get_or_try_init(py, || {
105-
fn init(py: Python<'_>) -> PyResult<TypeMappings> {
106-
let locals = PyDict::new(py);
107-
py.run(
108-
c"from neo4j._codec.packstream.v1.types import *",
109-
None,
110-
Some(&locals),
111-
)?;
112-
TypeMappings::new(&locals)
113-
}
114-
115-
if TYPE_MAPPINGS_INIT.swap(true, Ordering::SeqCst) {
116-
return Err(PyErr::new::<PyImportError, _>(
117-
"Cannot call _rust.pack while loading `neo4j._codec.packstream.v1.types`",
118-
));
119-
}
120-
Ok(init(py))
103+
let mappings = TYPE_MAPPINGS.get_or_init_py_attached(py, || {
104+
let locals = PyDict::new(py);
105+
py.run(
106+
c"from neo4j._codec.packstream.v1.types import *",
107+
None,
108+
Some(&locals),
109+
)?;
110+
TypeMappings::new(&locals)
121111
});
122-
mappings?.as_ref().map_err(|e| e.clone_ref(py))
112+
mappings.as_ref().map_err(|e| e.clone_ref(py))
123113
}
124114

125115
#[pyfunction]

0 commit comments

Comments
 (0)