diff --git a/newsfragments/5600.added.md b/newsfragments/5600.added.md new file mode 100644 index 00000000000..d88b53de6f7 --- /dev/null +++ b/newsfragments/5600.added.md @@ -0,0 +1 @@ +added per-module ModuleState, init + free lifecycle hooks and getters / setters in PyModuleMethods diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index a876a724b9d..d0bac8ac77e 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -531,7 +531,8 @@ fn module_initialization( #pyo3_path::impl_::trampoline::module_exec(module, #module_exec) } - static SLOTS: impl_::PyModuleSlots<4> = impl_::PyModuleSlotsBuilder::new() + static SLOTS: impl_::PyModuleSlots<5> = impl_::PyModuleSlotsBuilder::new() + .with_mod_exec(impl_::pyo3_module_state_init) .with_mod_exec(__pyo3_module_exec) .with_gil_used(#gil_used) .build(); diff --git a/src/impl_.rs b/src/impl_.rs index 364f43ca4f8..72ccc15328c 100644 --- a/src/impl_.rs +++ b/src/impl_.rs @@ -28,3 +28,5 @@ pub mod pymodule; pub mod trampoline; pub mod unindent; pub mod wrap; + +pub(crate) mod pymodule_state; diff --git a/src/impl_/pymodule.rs b/src/impl_/pymodule.rs index b1bf6fb8878..572217be2fb 100644 --- a/src/impl_/pymodule.rs +++ b/src/impl_/pymodule.rs @@ -4,6 +4,7 @@ use std::{ cell::UnsafeCell, ffi::CStr, marker::PhantomData, + mem::MaybeUninit, os::raw::{c_int, c_void}, }; @@ -24,10 +25,12 @@ use std::sync::atomic::{AtomicI64, Ordering}; #[cfg(not(any(PyPy, GraalPy)))] use crate::exceptions::PyImportError; +use crate::exceptions::PyRuntimeError; +use crate::impl_::trampoline::trampoline; use crate::prelude::PyTypeMethods; use crate::{ ffi, - impl_::pyfunction::PyFunctionDef, + impl_::{pyfunction::PyFunctionDef, pymodule_state::ModuleState}, sync::PyOnceLock, types::{any::PyAnyMethods, dict::PyDictMethods, PyDict, PyModule, PyModuleMethods}, Bound, Py, PyAny, PyClass, PyResult, PyTypeInfo, Python, @@ -77,9 +80,11 @@ impl ModuleDef { let ffi_def = UnsafeCell::new(ffi::PyModuleDef { m_name: name.as_ptr(), m_doc: doc.as_ptr(), + m_size: std::mem::size_of::() as _, // TODO: would be slightly nicer to use `[T]::as_mut_ptr()` here, // but that requires mut ptr deref on MSRV. m_slots: slots.0.get() as _, + m_free: Some(pyo3_module_state_free), ..INIT }); @@ -308,6 +313,40 @@ impl PyAddToModule for ModuleDef { } } +/// Called during multi-phase initialization in order to create an instance of +/// ModuleState on the memory area specific to modules. +/// +/// Slot: [`Py_mod_exec`] +/// +/// [`Py_mod_exec`]: https://docs.python.org/3/c-api/module.html#c.Py_mod_exec +pub unsafe extern "C" fn pyo3_module_state_init(module: *mut ffi::PyObject) -> c_int { + unsafe { + trampoline(|_| { + let state: *mut MaybeUninit = ffi::PyModule_GetState(module).cast(); + + // CPython builtins just assert this, but cross ffi panics are tricky, so we return an + // error instead + if state.is_null() { + return Err(PyRuntimeError::new_err("PyO3 per-module state was null. This is a bug in the Python interpreter runtime.")); + } + + (*state).write(ModuleState::new()); + + Ok(0) + }) + } +} + +/// Called during deallocation of the module object. +/// +/// Used for the [`m_free`] field of [`PyModuleDef`]. +/// +/// [`m_free`]: https://docs.python.org/3/c-api/module.html#c.PyModuleDef.m_free +/// [`PyModuleDef`]: https://docs.python.org/3/c-api/module.html#c.PyModuleDef +pub unsafe extern "C" fn pyo3_module_state_free(module: *mut c_void) { + unsafe { ModuleState::pymodule_free_state(module.cast()) }; +} + #[cfg(test)] mod tests { use std::{borrow::Cow, ffi::CStr, os::raw::c_int}; @@ -386,6 +425,63 @@ mod tests { } } + #[test] + fn module_state_init() { + use super::{pyo3_module_state_init, ModuleState}; + use crate::{PyAny, PyErr}; + + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + struct UserState(u64); + + unsafe extern "C" fn state_test(module: *mut ffi::PyObject) -> c_int { + unsafe { + trampoline::module_exec(module, |module| { + match ModuleState::pymodule_get_state(module.as_ptr()) { + Some(_) => Ok(()), + None => Err(PyErr::new::("failed to initialize ModuleState")), + } + }) + } + } + + static SLOTS: PyModuleSlots<5> = PyModuleSlotsBuilder::new() + .with_gil_used(false) + .with_mod_exec(pyo3_module_state_init) + .with_mod_exec(state_test) + .build(); + static MODULE_DEF: ModuleDef = ModuleDef::new( + c"test_module_state_init", + c"This test is for checking PyO3 ModuleState is initialized correctly", + &SLOTS, + ); + + Python::attach(|py| { + let mut module = MODULE_DEF + .make_module(py) + .expect("module to initialize without error") + .into_bound(py); + let mystate = UserState(42); + + assert_eq!( + None, + module.state_ref::(), + "no state has been added yet" + ); + + assert_eq!( + mystate, + *module.state_or_init(|| mystate), + "added state successfully" + ); + + assert_eq!( + Some(&mystate), + module.state_ref::(), + "previously added state is referenceable" + ); + }) + } + #[test] #[should_panic] fn test_module_slots_builder_overflow() { diff --git a/src/impl_/pymodule_state.rs b/src/impl_/pymodule_state.rs new file mode 100644 index 00000000000..dd6fa9e0ddf --- /dev/null +++ b/src/impl_/pymodule_state.rs @@ -0,0 +1,192 @@ +use std::ptr::NonNull; + +use crate::internal::typemap::{CloneAny, TypeMap}; +use crate::types::PyModule; +use crate::{ffi, Bound}; + +/// The internal typemap for [`ModuleState`] +pub type StateMap = TypeMap; + +/// A marker trait for indicating what type level guarantees (and requirements) +/// are made for PyO3 `PyModule` state types. +/// +/// In general, a type *must be* +/// +/// 1. Fully owned (`'static`) +/// 2. Cloneable (`Clone`) +/// 3. Sendable (`Send`) +/// +/// To qualify as `PyModule` state. +/// +/// This type is automatically implemented for all types that qualify, so no +/// further action is required. +pub trait ModuleStateType: Clone + Send {} +impl ModuleStateType for T {} + +/// Represents a Python module's state. +/// +/// More precisely, this `struct` resides on the per-module memory area +/// allocated during the module's creation. +#[repr(C)] +#[derive(Debug)] +pub struct ModuleState { + inner: Option>, +} + +impl ModuleState { + /// Create a new, empty [`ModuleState`] + pub fn new() -> Self { + let boxed = Box::new(StateCapsule::new()); + + Self { + inner: NonNull::new(Box::into_raw(boxed)), + } + } + + pub fn state_map_ref(&self) -> &StateMap { + &self.inner_ref().sm + } + + pub fn state_map_mut(&mut self) -> &mut StateMap { + &mut self.inner_mut().sm + } + + fn inner_ref(&self) -> &StateCapsule { + self.inner + .as_ref() + .map(|ptr| unsafe { ptr.as_ref() }) + .expect("BUG: ModuleState.inner should always be Some, except when dropping") + } + + fn inner_mut(&mut self) -> &mut StateCapsule { + self.inner + .as_mut() + .map(|ptr| unsafe { ptr.as_mut() }) + .expect("BUG: ModuleState.inner should always be Some, except when dropping") + } + + /// This is the actual [`Drop::drop`] implementation, split out + /// so we can run it on the state ptr returned from [`Self::pymodule_get_state`] + /// + /// While this function does not take a owned `self`, the calling ModuleState + /// should not be accessed again + /// + /// Calling this function multiple times on a single ModuleState is a noop, + /// beyond the first + unsafe fn drop_impl(&mut self) { + if let Some(ptr) = self.inner.take().map(|state| state.as_ptr()) { + // SAFETY: This ptr is allocated via Box::new in Self::new, and is + // non null + unsafe { drop(Box::from_raw(ptr)) } + } + } +} + +impl ModuleState { + /// Fetch the [`ModuleState`] from a bound PyModule, inheriting it's lifetime + /// + /// ## Panics + /// + /// This function can panic if called on a PyModule that has not yet been + /// initialized + pub(crate) fn from_bound<'a>(this: &'a Bound<'_, PyModule>) -> &'a Self { + unsafe { + Self::pymodule_get_state(this.as_ptr()) + .map(|ptr| ptr.as_ref()) + .expect("pyo3 PyModules should always have per-module state") + } + } + + /// Fetch the [`ModuleState`] mutably from a bound PyModule, inheriting it's + /// lifetime + /// + /// ## Panics + /// + /// This function can panic if called on a PyModule that has not yet been + /// initialized + pub(crate) fn from_bound_mut<'a>(this: &'a mut Bound<'_, PyModule>) -> &'a mut Self { + unsafe { + Self::pymodule_get_state(this.as_ptr()) + .map(|mut ptr| ptr.as_mut()) + .expect("pyo3 PyModules should always have per-module state") + } + } + + /// Associated low level function for retrieving a pyo3 `pymodule`'s state + /// + /// If this function returns None, it means the underlying C PyModule does + /// not have module state. + /// + /// This function should only be called on a PyModule that is already + /// initialized via PyModule_New (or Py_mod_create) + pub(crate) unsafe fn pymodule_get_state(module: *mut ffi::PyObject) -> Option> { + unsafe { + let state: *mut ModuleState = ffi::PyModule_GetState(module).cast(); + + match state.is_null() { + true => None, + false => Some(NonNull::new_unchecked(state)), + } + } + } + + /// Associated low level function for freeing our `pymodule`'s state + /// via a ModuleDef's m_free C callback + pub(crate) unsafe fn pymodule_free_state(module: *mut ffi::PyObject) { + unsafe { + if let Some(state) = Self::pymodule_get_state(module) { + // SAFETY: this callback is called when python is freeing the + // associated PyModule, so we should never be accessed again + (*state.as_ptr()).drop_impl() + } + } + } +} + +impl Drop for ModuleState { + fn drop(&mut self) { + // SAFETY: we're being dropped, so we'll never be accessed again + unsafe { self.drop_impl() }; + } +} + +impl Default for ModuleState { + fn default() -> Self { + Self::new() + } +} + +/// Inner layout of [`ModuleState`]. +#[derive(Debug, Clone)] +struct StateCapsule { + sm: StateMap, +} + +impl StateCapsule { + fn new() -> Self { + Self { + sm: StateMap::new(), + } + } +} + +impl Default for StateCapsule { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn type_assertions() { + fn is_send(_t: &T) {} + fn is_clone(_t: &T) {} + + let this = StateCapsule::new(); + is_send(&this); + is_clone(&this); + } +} diff --git a/src/internal.rs b/src/internal.rs index 7299f90ed03..ccf6dbc82a0 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -2,3 +2,4 @@ pub(crate) mod get_slot; pub(crate) mod state; +pub(crate) mod typemap; diff --git a/src/internal/typemap.rs b/src/internal/typemap.rs new file mode 100644 index 00000000000..62a66ad8e8f --- /dev/null +++ b/src/internal/typemap.rs @@ -0,0 +1,461 @@ +#![allow(dead_code)] + +use std::{ + any::{Any, TypeId}, + boxed::Box, + collections::{hash_map, HashMap}, + hash::{BuildHasherDefault, Hasher}, + marker::PhantomData, +}; + +use self::downcast::Downcast; +#[allow(unused_imports)] +pub use self::downcast::{CloneAny, IntoBox}; + +mod downcast; + +/// Raw access to the underlying `HashMap`. +pub type RawMap = HashMap, BuildHasherDefault>; + +/// A keyed TypeMap, for storing disjointed sets of types. +/// +/// This collection inherits the performance characteristics +/// of the underlying `HashMap`, namely ~O(1) lookups, +/// inserts and fetches. +#[derive(Debug)] +pub struct TypeMap { + raw: RawMap, +} + +// #[derive(Clone)] would want Ty to implement Clone, but in +// reality only Box can. +impl Clone for TypeMap +where + Box: Clone, +{ + #[inline] + fn clone(&self) -> TypeMap { + TypeMap { + raw: self.raw.clone(), + } + } +} + +impl Default for TypeMap { + #[inline] + fn default() -> TypeMap { + TypeMap::new() + } +} + +impl TypeMap { + /// Create an empty collection. + #[inline] + pub fn new() -> TypeMap { + TypeMap { + raw: RawMap::with_hasher(Default::default()), + } + } + + /// Creates an empty collection with the given initial + /// capacity. + #[inline] + pub fn with_capacity(capacity: usize) -> TypeMap { + TypeMap { + raw: RawMap::with_capacity_and_hasher(capacity, Default::default()), + } + } + + /// Returns the number of elements the collection can + /// hold without reallocating. + #[inline] + pub fn capacity(&self) -> usize { + self.raw.capacity() + } + + /// Reserves capacity for at least `additional` more + /// elements to be inserted in the collection. The + /// collection may reserve more space to avoid + /// frequent reallocations. + /// + /// # Panics + /// + /// Panics if the new allocation size overflows `usize`. + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.raw.reserve(additional) + } + + /// Shrinks the capacity of the collection as much as + /// possible. It will drop down as much as possible + /// while maintaining the internal rules + /// and possibly leaving some space in accordance with + /// the resize policy. + #[inline] + pub fn shrink_to_fit(&mut self) { + self.raw.shrink_to_fit() + } + + /// Returns the number of items in the collection. + #[inline] + pub fn len(&self) -> usize { + self.raw.len() + } + + /// Returns true if there are no items in the + /// collection. + #[inline] + pub fn is_empty(&self) -> bool { + self.raw.is_empty() + } + + /// Removes all items from the collection. Keeps the + /// allocated memory for reuse. + #[inline] + pub fn clear(&mut self) { + self.raw.clear() + } + + /// Returns a reference to the value stored in the + /// collection for the type `T`, if it exists. + #[inline] + pub fn get(&self) -> Option<&V> + where + V: IntoBox, + { + self.raw + .get(&Self::key_type_id::()) + .map(|ty| unsafe { ty.downcast_ref_unchecked::() }) + } + + /// Returns a mutable reference to the value stored in + /// the collection for the type `T`, if it exists. + #[inline] + pub fn get_mut(&mut self) -> Option<&mut V> + where + V: IntoBox, + { + self.raw + .get_mut(&Self::key_type_id::()) + .map(|ty| unsafe { ty.downcast_mut_unchecked::() }) + } + + /// Sets the value stored in the collection for the type + /// `T`. If the collection already had a value of + /// type `T`, that value is returned. Otherwise, + /// `None` is returned. + #[inline] + pub fn insert(&mut self, value: V) -> Option + where + V: IntoBox, + { + self.raw + .insert(Self::key_type_id::(), value.into_box()) + .map(|ty| unsafe { *ty.downcast_unchecked::() }) + } + + /// Removes the `T` value from the collection, + /// returning it if there was one or `None` if there was + /// not. + #[inline] + pub fn remove(&mut self) -> Option + where + V: IntoBox, + { + self.raw + .remove(&Self::key_type_id::()) + .map(|ty| *unsafe { ty.downcast_unchecked::() }) + } + + /// Returns true if the collection contains a value of + /// type `T`. + #[inline] + pub fn contains(&self) -> bool + where + V: IntoBox, + { + self.raw.contains_key(&Self::key_type_id::()) + } + + /// Gets the entry for the given type in the collection + /// for in-place manipulation + #[inline] + pub fn entry(&mut self) -> Entry<'_, Ty, V> + where + V: IntoBox, + { + match self.raw.entry(Self::key_type_id::()) { + hash_map::Entry::Occupied(e) => Entry::Occupied(OccupiedEntry { + inner: e, + type_: PhantomData, + }), + hash_map::Entry::Vacant(e) => Entry::Vacant(VacantEntry { + inner: e, + type_: PhantomData, + }), + } + } + + /// Get access to the raw hash map that backs this. + /// + /// This will seldom be useful, but it’s conceivable + /// that you could wish to iterate over all the + /// items in the collection, and this lets you do that. + #[inline] + pub fn as_raw(&self) -> &RawMap { + &self.raw + } + + /// Get mutable access to the raw hash map that backs + /// this. + /// + /// This will seldom be useful, but it’s conceivable + /// that you could wish to iterate over all the + /// items in the collection mutably, or drain or + /// something, or *possibly* even batch insert, and + /// this lets you do that. + /// + /// # Safety + /// + /// If you insert any values to the raw map, the key (a + /// `TypeId`) must match the value’s type, or + /// *undefined behaviour* will occur when you access + /// those values. + /// + /// (*Removing* entries is perfectly safe.) + #[inline] + pub unsafe fn as_raw_mut(&mut self) -> &mut RawMap { + &mut self.raw + } + + /// Convert this into the raw hash map that backs this. + /// + /// This will seldom be useful, but it’s conceivable + /// that you could wish to consume all the items in + /// the collection and do *something* with some or all + /// of them, and this lets you do that, without the + /// `unsafe` that `.as_raw_mut().drain()` would require. + #[inline] + pub fn into_raw(self) -> RawMap { + self.raw + } + + fn key_type_id() -> TypeMapKey + where + T: 'static, + { + TypeMapKey::with_typeid(TypeId::of::()) + } +} + +/// A view into a single occupied location in an `Map`. +pub struct OccupiedEntry<'a, Ty: ?Sized + Downcast, V: 'a> { + inner: hash_map::OccupiedEntry<'a, TypeMapKey, Box>, + type_: PhantomData, +} + +/// A view into a single empty location in an `Map`. +pub struct VacantEntry<'a, Ty: ?Sized + Downcast, V: 'a> { + inner: hash_map::VacantEntry<'a, TypeMapKey, Box>, + type_: PhantomData, +} + +/// A view into a single location in an `Map`, which may be +/// vacant or occupied. +pub enum Entry<'a, Ty: ?Sized + Downcast, V> { + /// An occupied Entry + Occupied(OccupiedEntry<'a, Ty, V>), + /// A vacant Entry + Vacant(VacantEntry<'a, Ty, V>), +} + +impl<'a, Ty: ?Sized + Downcast, V: IntoBox> Entry<'a, Ty, V> { + /// Ensures a value is in the entry by inserting the + /// default if empty, and returns + /// a mutable reference to the value in the entry. + #[inline] + pub fn or_insert(self, default: V) -> &'a mut V { + match self { + Entry::Occupied(inner) => inner.into_mut(), + Entry::Vacant(inner) => inner.insert(default), + } + } + + /// Ensures a value is in the entry by inserting the + /// result of the default function if empty, and + /// returns a mutable reference to the value in the + /// entry. + #[inline] + pub fn or_insert_with V>(self, default: F) -> &'a mut V { + match self { + Entry::Occupied(inner) => inner.into_mut(), + Entry::Vacant(inner) => inner.insert(default()), + } + } + + /// Ensures a value is in the entry by inserting the + /// default value if empty, and returns a mutable + /// reference to the value in the entry. + #[inline] + pub fn or_default(self) -> &'a mut V + where + V: Default, + { + match self { + Entry::Occupied(inner) => inner.into_mut(), + Entry::Vacant(inner) => inner.insert(Default::default()), + } + } + + /// Provides in-place mutable access to an occupied + /// entry before any potential inserts into the map. + #[inline] + pub fn and_modify(self, f: F) -> Self { + match self { + Entry::Occupied(mut inner) => { + f(inner.get_mut()); + Entry::Occupied(inner) + } + Entry::Vacant(inner) => Entry::Vacant(inner), + } + } + + // Additional stable methods (as of 1.60.0-nightly) that + // could be added: insert_entry(self, value: V) -> + // OccupiedEntry<'a, K, V> (1.59.0) +} + +impl<'a, Ty: ?Sized + Downcast, V: IntoBox> OccupiedEntry<'a, Ty, V> { + /// Gets a reference to the value in the entry + #[inline] + pub fn get(&self) -> &V { + unsafe { self.inner.get().downcast_ref_unchecked() } + } + + /// Gets a mutable reference to the value in the entry + #[inline] + pub fn get_mut(&mut self) -> &mut V { + unsafe { self.inner.get_mut().downcast_mut_unchecked() } + } + + /// Converts the OccupiedEntry into a mutable reference + /// to the value in the entry with a lifetime bound + /// to the collection itself + #[inline] + pub fn into_mut(self) -> &'a mut V { + unsafe { self.inner.into_mut().downcast_mut_unchecked() } + } + + /// Sets the value of the entry, and returns the entry's + /// old value + #[inline] + pub fn insert(&mut self, value: V) -> V { + unsafe { *self.inner.insert(value.into_box()).downcast_unchecked() } + } + + /// Takes the value out of the entry, and returns it + #[inline] + pub fn remove(self) -> V { + unsafe { *self.inner.remove().downcast_unchecked() } + } +} + +impl<'a, Ty: ?Sized + Downcast, V: IntoBox> VacantEntry<'a, Ty, V> { + /// Sets the value of the entry with the VacantEntry's + /// key, and returns a mutable reference to it + #[inline] + pub fn insert(self, value: V) -> &'a mut V { + unsafe { self.inner.insert(value.into_box()).downcast_mut_unchecked() } + } +} + +/// The map key for [`TypeMap`]. +/// +/// Typically, this can be considered an implementation +/// detail of the library, though if you're not +/// using [`tmkey!`] for deriving keys it may be useful. +/// +/// There are two variants of this type: +/// +/// 1. A [`TypeId`], probably of the relevant [`Key`] +/// 2. A prehashed value stored as a `u64` +/// +/// The second type allows for a limited form of runtime +/// dynamism, which the caller is responsible for ensuring +/// that the `u64` -> `T` pair is singular. +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct TypeMapKey { + k: TypeKey, +} + +impl TypeMapKey { + /// Create a new [`TypeId`] variant + pub const fn with_typeid(id: TypeId) -> Self { + Self { k: TypeKey::Id(id) } + } + + /// Create a new externally hashed variant + pub const fn with_exthash(value: u64) -> Self { + Self { + k: TypeKey::ExtHash(value), + } + } +} + +impl From for TypeMapKey { + fn from(id: TypeId) -> Self { + Self::with_typeid(id) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TypeKey { + /// A normal [`TypeId`] + Id(TypeId), + /// A value that has already been hashed, externally + ExtHash(u64), +} + +impl std::hash::Hash for TypeKey { + fn hash(&self, state: &mut H) { + match self { + TypeKey::Id(id) => id.hash(state), + TypeKey::ExtHash(pre) => pre.hash(state), + } + } +} + +/// A hasher designed to eke a little more speed out, given +/// `TypeId`’s known characteristics. +/// +/// This hasher effectively acts as a noop Hash +/// implementation, as we implicitly trust Rust's `TypeId` +/// uniqueness guarantee. +#[derive(Default)] +pub struct TypeIdHasher { + value: u64, +} + +impl Hasher for TypeIdHasher { + #[inline] + fn write(&mut self, bytes: &[u8]) { + // This expects to receive exactly one 64-bit value, and + // there’s no realistic chance of that changing, but + // I don’t want to depend on something that isn’t expressly + // part of the contract for safety. But I’m OK with + // release builds putting everything in one bucket + // if it *did* change (and debug builds panicking). + debug_assert_eq!(bytes.len(), 8); + let _ = bytes + .try_into() + .map(|array| self.value = u64::from_ne_bytes(array)); + } + + #[inline] + fn finish(&self) -> u64 { + self.value + } +} + +#[cfg(test)] +mod tests; diff --git a/src/internal/typemap/downcast.rs b/src/internal/typemap/downcast.rs new file mode 100644 index 00000000000..1f8d3ccdb0f --- /dev/null +++ b/src/internal/typemap/downcast.rs @@ -0,0 +1,165 @@ +use core::{ + any::{Any, TypeId}, + fmt, +}; + +use macros::{make_clone, make_downcast}; + +/// Methods for downcasting from an `Any` trait object. +/// +/// These should only be implemented for types that satisfy: +/// 1. Implements `Any` (including transitively) +/// +/// This includes most types, *excluding* ones that have a +/// non static lifetime -- references, `Struct<'a>`'s, etc +pub trait Downcast { + /// Gets the `TypeId` of `self`. + /// + /// If you can't implement this via a naive call to + /// Self::type_id() you probably shouldn't implement + /// this trait for your type(s). + fn type_id(&self) -> TypeId; + + /// Downcast from `Box` to `Box`, without + /// checking the type matches. + /// + /// ## Safety + /// + /// The caller must ensure that `T` matches the trait + /// object, via external means. + unsafe fn downcast_unchecked(self: Box) -> Box; + + /// Downcast from `&Any` to `&T`, without checking the + /// type matches. + /// + /// ## Safety + /// + /// The caller must ensure that `T` matches the trait + /// object, via external means. + unsafe fn downcast_ref_unchecked(&self) -> &T; + + /// Downcast from `&mut Any` to `&mut T`, without + /// checking the type matches. + /// + /// ## Safety + /// + /// The caller must ensure that `T` matches the trait + /// object, via external means. + unsafe fn downcast_mut_unchecked(&mut self) -> &mut T; +} + +/// A generic conversion of a type to a dyn trait object +pub trait IntoBox: Any { + fn into_box(self) -> Box; +} + +/// [`Any`], but with cloning. +/// +/// Every type with no non-`'static` references that +/// implements `Clone` implements `CloneAny`. +/// See [`core::any`] for more details on `Any` in general. +pub trait CloneAny: Any + CloneToAny {} +impl CloneAny for T {} + +/// This trait is used for library internals, please ignore +#[doc(hidden)] +pub trait CloneToAny { + /// Clone `self` into a new `Box` object. + fn clone_to_any(&self) -> Box; +} + +impl CloneToAny for T { + #[inline] + fn clone_to_any(&self) -> Box { + Box::new(self.clone()) + } +} + +/* Any */ + +make_downcast!(Any); +make_downcast!(Any + Send); +make_downcast!(Any + Send + Sync); + +/* CloneAny */ + +make_downcast!(CloneAny); +make_downcast!(CloneAny + Send); +make_downcast!(CloneAny + Send + Sync); +make_clone!(dyn CloneAny); +make_clone!(dyn CloneAny + Send); +make_clone!(dyn CloneAny + Send + Sync); + +mod macros { + /// Implement `Downcast` for the given $trait + macro_rules! make_downcast { + ($any_trait:ident $(+ $auto_traits:ident)*) => { + impl Downcast for dyn $any_trait $(+ $auto_traits)* { + #[inline] + fn type_id(&self) -> TypeId { + self.type_id() + } + + #[inline] + unsafe fn downcast_ref_unchecked(&self) -> &T { + unsafe { &*(self as *const Self as *const T) } + } + + #[inline] + unsafe fn downcast_mut_unchecked(&mut self) -> &mut T { + unsafe { &mut *(self as *mut Self as *mut T) } + } + + #[inline] + unsafe fn downcast_unchecked(self: Box) -> Box { + unsafe { Box::from_raw(Box::into_raw(self) as *mut T) } + } + } + + impl IntoBox for T { + #[inline] + fn into_box(self) -> Box { + Box::new(self) + } + } + } + } + + /// Implement `Clone` for the given $type + /// + /// We also implement a naive `Debug` output that prints + /// the $type name + macro_rules! make_clone { + ($t:ty) => { + impl Clone for Box<$t> { + #[inline] + fn clone(&self) -> Box<$t> { + let clone: Box = (**self).clone_to_any(); + let raw: *mut dyn CloneAny = Box::into_raw(clone); + + // We can't do a normal ptr cast here as we get a lint about + // a future hard + // error, `ptr_cast_add_auto_to_object`. + // + // This issue doesn't apply here, because we don't have any + // conditional methods (`CloneAny` always and only + // requires `Any`). Alas, we still have to + // transmute(), however to avoid the pesky lint + // + // https://github.com/rust-lang/rust/issues/127323 + unsafe { Box::from_raw(std::mem::transmute::<*mut dyn CloneAny, *mut _>(raw)) } + } + } + + impl fmt::Debug for $t { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.pad(stringify!($t)) + } + } + }; + } + + pub(super) use make_clone; + pub(super) use make_downcast; +} diff --git a/src/internal/typemap/tests.rs b/src/internal/typemap/tests.rs new file mode 100644 index 00000000000..bc2ad81e292 --- /dev/null +++ b/src/internal/typemap/tests.rs @@ -0,0 +1,203 @@ +use super::CloneAny; +use super::*; + +#[derive(Clone, Debug, PartialEq)] +struct A(i32); +#[derive(Clone, Debug, PartialEq)] +struct B(i32); +#[derive(Clone, Debug, PartialEq)] +struct C(i32); +#[derive(Clone, Debug, PartialEq)] +struct D(i32); +#[derive(Clone, Debug, PartialEq)] +struct E(i32); +#[derive(Clone, Debug, PartialEq)] +struct F(i32); +#[derive(Clone, Debug, PartialEq)] +struct J(i32); + +#[test] +fn test_default() { + let map: TypeMap = Default::default(); + assert_eq!(map.len(), 0); +} + +#[test] +fn test_expected_traits() { + fn assert_send() {} + fn assert_sync() {} + fn assert_clone() {} + fn assert_debug() {} + + assert_send::>(); + assert_send::>(); + assert_sync::>(); + assert_debug::>(); + assert_debug::>(); + assert_debug::>(); + + assert_send::>(); + assert_send::>(); + assert_sync::>(); + assert_debug::>(); + assert_debug::>(); + assert_debug::>(); + assert_clone::>(); + assert_clone::>(); + assert_clone::>(); +} + +#[test] +fn test_variants() { + /* dyn Any (+ variants) */ + + let mut tm: TypeMap = TypeMap::new(); + { + assert_eq!(tm.insert(A(10)), None); + assert_eq!(tm.insert(B(20)), None); + assert_eq!(tm.insert(C(30)), None); + assert_eq!(tm.insert(D(40)), None); + assert_eq!(tm.insert(E(50)), None); + assert_eq!(tm.insert(F(60)), None); + + // Existing key (insert) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut view) => { + assert_eq!(view.get(), &A(10)); + assert_eq!(view.insert(A(100)), A(10)); + } + } + assert_eq!(tm.get::().unwrap(), &A(100)); + assert_eq!(tm.len(), 6); + + // Existing key (update) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut view) => { + let v = view.get_mut(); + let new_v = B(v.0 * 10); + *v = new_v; + } + } + assert_eq!(tm.get::().unwrap(), &B(200)); + assert_eq!(tm.len(), 6); + + // Existing key (remove) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(view) => { + assert_eq!(view.remove(), C(30)); + } + } + assert_eq!(tm.get::(), None); + assert_eq!(tm.len(), 5); + + // Inexistent key (insert) + match tm.entry::() { + Entry::Occupied(_) => unreachable!(), + Entry::Vacant(view) => { + assert_eq!(*view.insert(J(1000)), J(1000)); + } + } + assert_eq!(tm.get::().unwrap(), &J(1000)); + assert_eq!(tm.len(), 6); + + // Entry.or_insert on existing key + tm.entry::().or_insert(B(71)).0 += 1; + assert_eq!(tm.get::().unwrap(), &B(201)); + assert_eq!(tm.len(), 6); + + // Entry.or_insert on nonexisting key + tm.entry::().or_insert(C(300)).0 += 1; + assert_eq!(tm.get::().unwrap(), &C(301)); + assert_eq!(tm.len(), 7); + } + + /* dyn CloneAny (+ variants) */ + + let mut tm: TypeMap = TypeMap::new(); + { + assert_eq!(tm.insert(A(10)), None); + assert_eq!(tm.insert(B(20)), None); + assert_eq!(tm.insert(C(30)), None); + assert_eq!(tm.insert(D(40)), None); + assert_eq!(tm.insert(E(50)), None); + assert_eq!(tm.insert(F(60)), None); + + // Existing key (insert) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut view) => { + assert_eq!(view.get(), &A(10)); + assert_eq!(view.insert(A(100)), A(10)); + } + } + assert_eq!(tm.get::().unwrap(), &A(100)); + assert_eq!(tm.len(), 6); + + // Existing key (update) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(mut view) => { + let v = view.get_mut(); + let new_v = B(v.0 * 10); + *v = new_v; + } + } + assert_eq!(tm.get::().unwrap(), &B(200)); + assert_eq!(tm.len(), 6); + + // Existing key (remove) + match tm.entry::() { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(view) => { + assert_eq!(view.remove(), C(30)); + } + } + assert_eq!(tm.get::(), None); + assert_eq!(tm.len(), 5); + + // Inexistent key (insert) + match tm.entry::() { + Entry::Occupied(_) => unreachable!(), + Entry::Vacant(view) => { + assert_eq!(*view.insert(J(1000)), J(1000)); + } + } + assert_eq!(tm.get::().unwrap(), &J(1000)); + assert_eq!(tm.len(), 6); + + // Entry.or_insert on existing key + tm.entry::().or_insert(B(71)).0 += 1; + assert_eq!(tm.get::().unwrap(), &B(201)); + assert_eq!(tm.len(), 6); + + // Entry.or_insert on nonexisting key + tm.entry::().or_insert(C(300)).0 += 1; + assert_eq!(tm.get::().unwrap(), &C(301)); + assert_eq!(tm.len(), 7); + } +} + +#[test] +fn test_clone() { + let mut tm: TypeMap = TypeMap::new(); + let _ = tm.insert(A(1)); + let _ = tm.insert(B(2)); + /* No C */ + let _ = tm.insert(D(3)); + let _ = tm.insert(E(4)); + let _ = tm.insert(F(5)); + let _ = tm.insert(J(6)); + let tm2 = tm.clone(); + + assert_eq!(tm2.len(), 6); + assert_eq!(tm2.get::(), Some(&A(1))); + assert_eq!(tm2.get::(), Some(&B(2))); + assert_eq!(tm2.get::(), None); + assert_eq!(tm2.get::(), Some(&D(3))); + assert_eq!(tm2.get::(), Some(&E(4))); + assert_eq!(tm2.get::(), Some(&F(5))); + assert_eq!(tm2.get::(), Some(&J(6))); +} diff --git a/src/types/module.rs b/src/types/module.rs index 795ec737eeb..99ef484c3c6 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -1,6 +1,7 @@ use crate::err::{PyErr, PyResult}; use crate::ffi_ptr_ext::FfiPtrExt; use crate::impl_::callback::IntoPyCallbackOutput; +use crate::impl_::pymodule_state::{ModuleState, ModuleStateType}; use crate::py_result_ext::PyResultExt; use crate::pyclass::PyClass; use crate::types::{ @@ -402,6 +403,27 @@ pub trait PyModuleMethods<'py>: crate::sealed::Sealed { /// /// This is a no-op on the GIL-enabled build. fn gil_used(&self, gil_used: bool) -> PyResult<()>; + + /// Get an immutable view into the per-module state associated with this + /// PyModule. + fn state_ref(&self) -> Option<&T> + where + T: ModuleStateType + 'static; + + /// Get a mutable view into the per-module state associated with this + /// PyModule. + fn state_mut(&mut self) -> Option<&mut T> + where + T: ModuleStateType + 'static; + + /// Get a mutable view into the per-module state associated with this + /// PyModule. + /// + /// Will initialize the state type with the given `f` if needed. + fn state_or_init(&mut self, f: F) -> &mut T + where + T: ModuleStateType + 'static, + F: FnOnce() -> T; } impl<'py> PyModuleMethods<'py> for Bound<'py, PyModule> { @@ -549,6 +571,33 @@ impl<'py> PyModuleMethods<'py> for Bound<'py, PyModule> { #[cfg(any(Py_LIMITED_API, not(Py_GIL_DISABLED)))] Ok(()) } + + fn state_ref(&self) -> Option<&T> + where + T: ModuleStateType + 'static, + { + ModuleState::from_bound(self).state_map_ref().get::() + } + + fn state_mut(&mut self) -> Option<&mut T> + where + T: ModuleStateType + 'static, + { + ModuleState::from_bound_mut(self) + .state_map_mut() + .get_mut::() + } + + fn state_or_init(&mut self, f: F) -> &mut T + where + T: ModuleStateType + 'static, + F: FnOnce() -> T, + { + ModuleState::from_bound_mut(self) + .state_map_mut() + .entry::() + .or_insert_with(f) + } } fn __all__(py: Python<'_>) -> &Bound<'_, PyString> {