5050//!
5151//! [datetime]: https://numpy.org/doc/stable/reference/arrays.datetime.html
5252
53+ use std:: cell:: UnsafeCell ;
54+ use std:: collections:: hash_map:: Entry ;
5355use std:: fmt;
5456use std:: hash:: Hash ;
5557use std:: marker:: PhantomData ;
5658
57- use pyo3:: Python ;
59+ use ahash:: AHashMap ;
60+ use pyo3:: { Py , Python } ;
5861
5962use crate :: dtype:: { Element , PyArrayDescr } ;
6063use crate :: npyffi:: { PyArray_DatetimeDTypeMetaData , NPY_DATETIMEUNIT , NPY_TYPES } ;
@@ -147,14 +150,9 @@ unsafe impl<U: Unit> Element for Datetime<U> {
147150 const IS_COPY : bool = true ;
148151
149152 fn get_dtype ( py : Python ) -> & PyArrayDescr {
150- // FIXME(adamreichold): Memoize these via the Unit trait
151- let dtype = PyArrayDescr :: new_from_npy_type ( py, NPY_TYPES :: NPY_DATETIME ) ;
153+ static DTYPES : TypeDescriptors = unsafe { TypeDescriptors :: new ( NPY_TYPES :: NPY_DATETIME ) } ;
152154
153- unsafe {
154- set_unit ( dtype, U :: UNIT ) ;
155- }
156-
157- dtype
155+ DTYPES . from_unit ( py, U :: UNIT )
158156 }
159157}
160158
@@ -187,14 +185,9 @@ unsafe impl<U: Unit> Element for Timedelta<U> {
187185 const IS_COPY : bool = true ;
188186
189187 fn get_dtype ( py : Python ) -> & PyArrayDescr {
190- // FIXME(adamreichold): Memoize these via the Unit trait
191- let dtype = PyArrayDescr :: new_from_npy_type ( py, NPY_TYPES :: NPY_TIMEDELTA ) ;
188+ static DTYPES : TypeDescriptors = unsafe { TypeDescriptors :: new ( NPY_TYPES :: NPY_TIMEDELTA ) } ;
192189
193- unsafe {
194- set_unit ( dtype, U :: UNIT ) ;
195- }
196-
197- dtype
190+ DTYPES . from_unit ( py, U :: UNIT )
198191 }
199192}
200193
@@ -204,11 +197,50 @@ impl<U: Unit> fmt::Debug for Timedelta<U> {
204197 }
205198}
206199
207- unsafe fn set_unit ( dtype : & PyArrayDescr , unit : NPY_DATETIMEUNIT ) {
208- let metadata = & mut * ( ( * dtype. as_dtype_ptr ( ) ) . c_metadata as * mut PyArray_DatetimeDTypeMetaData ) ;
200+ struct TypeDescriptors {
201+ npy_type : NPY_TYPES ,
202+ dtypes : UnsafeCell < Option < AHashMap < NPY_DATETIMEUNIT , Py < PyArrayDescr > > > > ,
203+ }
204+
205+ unsafe impl Sync for TypeDescriptors { }
206+
207+ impl TypeDescriptors {
208+ /// `npy_type` must be either `NPY_DATETIME` or `NPY_TIMEDELTA`.
209+ const unsafe fn new ( npy_type : NPY_TYPES ) -> Self {
210+ Self {
211+ npy_type,
212+ dtypes : UnsafeCell :: new ( None ) ,
213+ }
214+ }
209215
210- metadata. meta . base = unit;
211- metadata. meta . num = 1 ;
216+ #[ allow( clippy:: mut_from_ref) ]
217+ unsafe fn get ( & self ) -> & mut AHashMap < NPY_DATETIMEUNIT , Py < PyArrayDescr > > {
218+ ( * self . dtypes . get ( ) ) . get_or_insert_with ( AHashMap :: new)
219+ }
220+
221+ #[ allow( clippy:: wrong_self_convention) ]
222+ fn from_unit < ' py > ( & ' py self , py : Python < ' py > , unit : NPY_DATETIMEUNIT ) -> & ' py PyArrayDescr {
223+ // SAFETY: We hold the GIL and we do not call into user code which might re-enter.
224+ let dtypes = unsafe { self . get ( ) } ;
225+
226+ match dtypes. entry ( unit) {
227+ Entry :: Occupied ( entry) => entry. into_mut ( ) . as_ref ( py) ,
228+ Entry :: Vacant ( entry) => {
229+ let dtype = PyArrayDescr :: new_from_npy_type ( py, self . npy_type ) ;
230+
231+ // SAFETY: `self.npy_type` is either `NPY_DATETIME` or `NPY_TIMEDELTA` which implies the type of `c_metadata`.
232+ unsafe {
233+ let metadata = & mut * ( ( * dtype. as_dtype_ptr ( ) ) . c_metadata
234+ as * mut PyArray_DatetimeDTypeMetaData ) ;
235+
236+ metadata. meta . base = unit;
237+ metadata. meta . num = 1 ;
238+ }
239+
240+ entry. insert ( dtype. into ( ) ) . as_ref ( py)
241+ }
242+ }
243+ }
212244}
213245
214246#[ cfg( test) ]
0 commit comments