@@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127)
4949class dtype; // Forward declaration
5050class array ; // Forward declaration
5151
52+ template <typename >
53+ struct numpy_scalar ; // Forward declaration
54+
5255PYBIND11_NAMESPACE_BEGIN (detail)
5356
5457template <>
@@ -245,6 +248,21 @@ struct npy_api {
245248 NPY_UINT64_
246249 = platform_lookup<std::uint64_t , unsigned long , unsigned long long , unsigned int >(
247250 NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
251+ NPY_FLOAT32_ = platform_lookup<float , double , float , long double >(
252+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
253+ NPY_FLOAT64_ = platform_lookup<double , double , float , long double >(
254+ NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
255+ NPY_COMPLEX64_
256+ = platform_lookup<std::complex <float >,
257+ std::complex <double >,
258+ std::complex <float >,
259+ std::complex <long double >>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
260+ NPY_COMPLEX128_
261+ = platform_lookup<std::complex <double >,
262+ std::complex <double >,
263+ std::complex <float >,
264+ std::complex <long double >>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_),
265+ NPY_CHAR_ = std::is_signed<char >::value ? NPY_BYTE_ : NPY_UBYTE_,
248266 };
249267
250268 unsigned int PyArray_RUNTIME_VERSION_;
@@ -268,6 +286,7 @@ struct npy_api {
268286
269287 unsigned int (*PyArray_GetNDArrayCFeatureVersion_)();
270288 PyObject *(*PyArray_DescrFromType_)(int );
289+ PyObject *(*PyArray_TypeObjectFromType_)(int );
271290 PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *,
272291 PyObject *,
273292 int ,
@@ -284,6 +303,8 @@ struct npy_api {
284303 PyTypeObject *PyVoidArrType_Type_;
285304 PyTypeObject *PyArrayDescr_Type_;
286305 PyObject *(*PyArray_DescrFromScalar_)(PyObject *);
306+ PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *);
307+ void (*PyArray_ScalarAsCtype_)(PyObject *, void *);
287308 PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int , int , int , PyObject *);
288309 int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
289310 bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
@@ -301,7 +322,10 @@ struct npy_api {
301322 API_PyArrayDescr_Type = 3 ,
302323 API_PyVoidArrType_Type = 39 ,
303324 API_PyArray_DescrFromType = 45 ,
325+ API_PyArray_TypeObjectFromType = 46 ,
304326 API_PyArray_DescrFromScalar = 57 ,
327+ API_PyArray_Scalar = 60 ,
328+ API_PyArray_ScalarAsCtype = 62 ,
305329 API_PyArray_FromAny = 69 ,
306330 API_PyArray_Resize = 80 ,
307331 // CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82.
@@ -336,7 +360,10 @@ struct npy_api {
336360 DECL_NPY_API (PyVoidArrType_Type);
337361 DECL_NPY_API (PyArrayDescr_Type);
338362 DECL_NPY_API (PyArray_DescrFromType);
363+ DECL_NPY_API (PyArray_TypeObjectFromType);
339364 DECL_NPY_API (PyArray_DescrFromScalar);
365+ DECL_NPY_API (PyArray_Scalar);
366+ DECL_NPY_API (PyArray_ScalarAsCtype);
340367 DECL_NPY_API (PyArray_FromAny);
341368 DECL_NPY_API (PyArray_Resize);
342369 DECL_NPY_API (PyArray_CopyInto);
@@ -355,6 +382,83 @@ struct npy_api {
355382 }
356383};
357384
385+ template <typename T>
386+ struct is_complex : std::false_type {};
387+ template <typename T>
388+ struct is_complex <std::complex <T>> : std::true_type {};
389+
390+ template <typename T, typename = void >
391+ struct npy_format_descriptor_name ;
392+
393+ template <typename T>
394+ struct npy_format_descriptor_name <T, enable_if_t <std::is_integral<T>::value>> {
395+ static constexpr auto name = const_name<std::is_same<T, bool >::value>(
396+ const_name (" numpy.bool" ),
397+ const_name<std::is_signed<T>::value>(" numpy.int" , " numpy.uint" )
398+ + const_name<sizeof (T) * 8 >());
399+ };
400+
401+ template <typename T>
402+ struct npy_format_descriptor_name <T, enable_if_t <std::is_floating_point<T>::value>> {
403+ static constexpr auto name = const_name < std::is_same<T, float >::value
404+ || std::is_same<T, const float >::value
405+ || std::is_same<T, double >::value
406+ || std::is_same<T, const double >::value
407+ > (const_name(" numpy.float" ) + const_name<sizeof (T) * 8 >(),
408+ const_name (" numpy.longdouble" ));
409+ };
410+
411+ template <typename T>
412+ struct npy_format_descriptor_name <T, enable_if_t <is_complex<T>::value>> {
413+ static constexpr auto name = const_name < std::is_same<typename T::value_type, float >::value
414+ || std::is_same<typename T::value_type, const float >::value
415+ || std::is_same<typename T::value_type, double >::value
416+ || std::is_same<typename T::value_type, const double >::value
417+ > (const_name(" numpy.complex" )
418+ + const_name<sizeof (typename T::value_type) * 16 >(),
419+ const_name (" numpy.longcomplex" ));
420+ };
421+
422+ template <typename T>
423+ struct numpy_scalar_info {};
424+
425+ #define PYBIND11_NUMPY_SCALAR_IMPL (ctype_, typenum_ ) \
426+ template <> \
427+ struct numpy_scalar_info <ctype_> { \
428+ static constexpr auto name = npy_format_descriptor_name<ctype_>::name; \
429+ static constexpr int typenum = npy_api::typenum_##_; \
430+ }
431+
432+ // boolean type
433+ PYBIND11_NUMPY_SCALAR_IMPL (bool , NPY_BOOL);
434+
435+ // character types
436+ PYBIND11_NUMPY_SCALAR_IMPL (char , NPY_CHAR);
437+ PYBIND11_NUMPY_SCALAR_IMPL (signed char , NPY_BYTE);
438+ PYBIND11_NUMPY_SCALAR_IMPL (unsigned char , NPY_UBYTE);
439+
440+ // signed integer types
441+ PYBIND11_NUMPY_SCALAR_IMPL (std::int16_t , NPY_INT16);
442+ PYBIND11_NUMPY_SCALAR_IMPL (std::int32_t , NPY_INT32);
443+ PYBIND11_NUMPY_SCALAR_IMPL (std::int64_t , NPY_INT64);
444+
445+ // unsigned integer types
446+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint16_t , NPY_UINT16);
447+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint32_t , NPY_UINT32);
448+ PYBIND11_NUMPY_SCALAR_IMPL (std::uint64_t , NPY_UINT64);
449+
450+ // floating point types
451+ PYBIND11_NUMPY_SCALAR_IMPL (float , NPY_FLOAT);
452+ PYBIND11_NUMPY_SCALAR_IMPL (double , NPY_DOUBLE);
453+ PYBIND11_NUMPY_SCALAR_IMPL (long double , NPY_LONGDOUBLE);
454+
455+ // complex types
456+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex <float >, NPY_CFLOAT);
457+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex <double >, NPY_CDOUBLE);
458+ PYBIND11_NUMPY_SCALAR_IMPL (std::complex <long double >, NPY_CLONGDOUBLE);
459+
460+ #undef PYBIND11_NUMPY_SCALAR_IMPL
461+
358462// This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ...
359463// This is needed to correctly handle situations where multiple typenums map to the same type,
360464// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
@@ -453,10 +557,6 @@ template <typename T>
453557struct is_std_array : std::false_type {};
454558template <typename T, size_t N>
455559struct is_std_array <std::array<T, N>> : std::true_type {};
456- template <typename T>
457- struct is_complex : std::false_type {};
458- template <typename T>
459- struct is_complex <std::complex <T>> : std::true_type {};
460560
461561template <typename T>
462562struct array_info_scalar {
@@ -670,8 +770,65 @@ template <typename T, ssize_t Dim>
670770struct type_caster <unchecked_mutable_reference<T, Dim>>
671771 : type_caster<unchecked_reference<T, Dim>> {};
672772
773+ template <typename T>
774+ struct type_caster <numpy_scalar<T>> {
775+ using value_type = T;
776+ using type_info = numpy_scalar_info<T>;
777+
778+ PYBIND11_TYPE_CASTER (numpy_scalar<T>, type_info::name);
779+
780+ static handle &target_type () {
781+ static handle tp = npy_api::get ().PyArray_TypeObjectFromType_ (type_info::typenum);
782+ return tp;
783+ }
784+
785+ static handle &target_dtype () {
786+ static handle tp = npy_api::get ().PyArray_DescrFromType_ (type_info::typenum);
787+ return tp;
788+ }
789+
790+ bool load (handle src, bool ) {
791+ if (isinstance (src, target_type ())) {
792+ npy_api::get ().PyArray_ScalarAsCtype_ (src.ptr (), &value.value );
793+ return true ;
794+ }
795+ return false ;
796+ }
797+
798+ static handle cast (numpy_scalar<T> src, return_value_policy, handle) {
799+ return npy_api::get ().PyArray_Scalar_ (&src.value , target_dtype ().ptr (), nullptr );
800+ }
801+ };
802+
673803PYBIND11_NAMESPACE_END (detail)
674804
805+ template <typename T>
806+ struct numpy_scalar {
807+ using value_type = T;
808+
809+ value_type value;
810+
811+ numpy_scalar () = default ;
812+ explicit numpy_scalar (value_type value) : value (value) {}
813+
814+ explicit operator value_type () const { return value; }
815+ numpy_scalar &operator =(value_type value) {
816+ this ->value = value;
817+ return *this ;
818+ }
819+
820+ friend bool operator ==(const numpy_scalar &a, const numpy_scalar &b) {
821+ return a.value == b.value ;
822+ }
823+
824+ friend bool operator !=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); }
825+ };
826+
827+ template <typename T>
828+ numpy_scalar<T> make_scalar (T value) {
829+ return numpy_scalar<T>(value);
830+ }
831+
675832class dtype : public object {
676833public:
677834 PYBIND11_OBJECT_DEFAULT (dtype, object, detail::npy_api::get().PyArrayDescr_Check_)
@@ -1409,38 +1566,6 @@ struct compare_buffer_info<T, detail::enable_if_t<detail::is_pod_struct<T>::valu
14091566 }
14101567};
14111568
1412- template <typename T, typename = void >
1413- struct npy_format_descriptor_name ;
1414-
1415- template <typename T>
1416- struct npy_format_descriptor_name <T, enable_if_t <std::is_integral<T>::value>> {
1417- static constexpr auto name = const_name<std::is_same<T, bool >::value>(
1418- const_name (" bool" ),
1419- const_name<std::is_signed<T>::value>(" numpy.int" , " numpy.uint" )
1420- + const_name<sizeof (T) * 8 >());
1421- };
1422-
1423- template <typename T>
1424- struct npy_format_descriptor_name <T, enable_if_t <std::is_floating_point<T>::value>> {
1425- static constexpr auto name = const_name < std::is_same<T, float >::value
1426- || std::is_same<T, const float >::value
1427- || std::is_same<T, double >::value
1428- || std::is_same<T, const double >::value
1429- > (const_name(" numpy.float" ) + const_name<sizeof (T) * 8 >(),
1430- const_name (" numpy.longdouble" ));
1431- };
1432-
1433- template <typename T>
1434- struct npy_format_descriptor_name <T, enable_if_t <is_complex<T>::value>> {
1435- static constexpr auto name = const_name < std::is_same<typename T::value_type, float >::value
1436- || std::is_same<typename T::value_type, const float >::value
1437- || std::is_same<typename T::value_type, double >::value
1438- || std::is_same<typename T::value_type, const double >::value
1439- > (const_name(" numpy.complex" )
1440- + const_name<sizeof (typename T::value_type) * 16 >(),
1441- const_name (" numpy.longcomplex" ));
1442- };
1443-
14441569template <typename T>
14451570struct npy_format_descriptor <
14461571 T,
0 commit comments