Skip to content

Commit f41dae3

Browse files
authored
Add dtype::normalized_num and dtype::num_of (#5429)
* Add dtype::normalized_num and dtype::num_of * Fix compiler warning and improve NumPy 1.x compatibility * Fix clang-tidy warning * Fix another clang-tidy warning * Add extra comment
1 parent b9fb316 commit f41dae3

File tree

3 files changed

+216
-1
lines changed

3 files changed

+216
-1
lines changed

include/pybind11/numpy.h

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ constexpr int platform_lookup(int I, Ints... Is) {
212212
}
213213

214214
struct npy_api {
215+
// If you change this code, please review `normalized_dtype_num` below.
215216
enum constants {
216217
NPY_ARRAY_C_CONTIGUOUS_ = 0x0001,
217218
NPY_ARRAY_F_CONTIGUOUS_ = 0x0002,
@@ -384,6 +385,74 @@ struct npy_api {
384385
}
385386
};
386387

388+
// This table normalizes typenums by mapping NPY_INT_, NPY_LONG_, ... to NPY_INT32_, NPY_INT64_, ...
389+
// This is needed to correctly handle situations where multiple typenums map to the same type,
390+
// e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different
391+
// typenum. The normalized typenum should always match the values used in npy_format_descriptor.
392+
// If you change this code, please review `enum constants` above.
393+
// Selects the fixed-width signed typenum whose width equals sizeof(T). When T is not 16, 32 or
// 64 bits wide, the platform-specific typenum passed as `fallback` is kept unchanged.
template <typename T>
constexpr int normalized_signed_dtype_num(int fallback) {
    return sizeof(T) == sizeof(std::int16_t)   ? static_cast<int>(npy_api::NPY_INT16_)
           : sizeof(T) == sizeof(std::int32_t) ? static_cast<int>(npy_api::NPY_INT32_)
           : sizeof(T) == sizeof(std::int64_t) ? static_cast<int>(npy_api::NPY_INT64_)
                                               : fallback;
}

// Unsigned counterpart of `normalized_signed_dtype_num`.
template <typename T>
constexpr int normalized_unsigned_dtype_num(int fallback) {
    return sizeof(T) == sizeof(std::uint16_t)   ? static_cast<int>(npy_api::NPY_UINT16_)
           : sizeof(T) == sizeof(std::uint32_t) ? static_cast<int>(npy_api::NPY_UINT32_)
           : sizeof(T) == sizeof(std::uint64_t) ? static_cast<int>(npy_api::NPY_UINT64_)
                                                : fallback;
}

// Lookup table from a raw typenum (the array index) to its normalized typenum (the entry).
// The trailing comment on each entry names the raw typenum that selects it, so the entries
// must stay in the same order as the typenum constants of `npy_api`.
static constexpr int normalized_dtype_num[npy_api::NPY_VOID_ + 1] = {
    npy_api::NPY_BOOL_,                                                        // NPY_BOOL_
    npy_api::NPY_BYTE_,                                                        // NPY_BYTE_
    npy_api::NPY_UBYTE_,                                                       // NPY_UBYTE_
    npy_api::NPY_INT16_,                                                       // NPY_SHORT_
    npy_api::NPY_UINT16_,                                                      // NPY_USHORT_
    normalized_signed_dtype_num<int>(npy_api::NPY_INT_),                       // NPY_INT_
    normalized_unsigned_dtype_num<unsigned int>(npy_api::NPY_UINT_),           // NPY_UINT_
    normalized_signed_dtype_num<long>(npy_api::NPY_LONG_),                     // NPY_LONG_
    normalized_unsigned_dtype_num<unsigned long>(npy_api::NPY_ULONG_),         // NPY_ULONG_
    normalized_signed_dtype_num<long long>(npy_api::NPY_LONGLONG_),            // NPY_LONGLONG_
    normalized_unsigned_dtype_num<unsigned long long>(npy_api::NPY_ULONGLONG_), // NPY_ULONGLONG_
    npy_api::NPY_FLOAT_,                                                       // NPY_FLOAT_
    npy_api::NPY_DOUBLE_,                                                      // NPY_DOUBLE_
    npy_api::NPY_LONGDOUBLE_,                                                  // NPY_LONGDOUBLE_
    npy_api::NPY_CFLOAT_,                                                      // NPY_CFLOAT_
    npy_api::NPY_CDOUBLE_,                                                     // NPY_CDOUBLE_
    npy_api::NPY_CLONGDOUBLE_,                                                 // NPY_CLONGDOUBLE_
    npy_api::NPY_OBJECT_,                                                      // NPY_OBJECT_
    npy_api::NPY_STRING_,                                                      // NPY_STRING_
    npy_api::NPY_UNICODE_,                                                     // NPY_UNICODE_
    npy_api::NPY_VOID_,                                                        // NPY_VOID_
};
455+
387456
inline PyArray_Proxy *array_proxy(void *ptr) { return reinterpret_cast<PyArray_Proxy *>(ptr); }
388457

389458
inline const PyArray_Proxy *array_proxy(const void *ptr) {
@@ -684,6 +753,13 @@ class dtype : public object {
684753
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::dtype();
685754
}
686755

756+
/// Return the type number associated with a C++ type.
757+
/// This is the constexpr equivalent of `dtype::of<T>().num()`.
758+
template <typename T>
759+
static constexpr int num_of() {
760+
return detail::npy_format_descriptor<typename std::remove_cv<T>::type>::value;
761+
}
762+
687763
/// Size of the data type in bytes.
688764
#ifdef PYBIND11_NUMPY_1_ONLY
689765
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
@@ -725,14 +801,27 @@ class dtype : public object {
725801
return detail::array_descriptor_proxy(m_ptr)->type;
726802
}
727803

728-
/// type number of dtype.
804+
/// Type number of dtype. Note that different values may be returned for equivalent types,
805+
/// e.g. even though ``long`` may be equivalent to ``int`` or ``long long``, they still have
806+
/// different type numbers. Consider using `normalized_num` to avoid this.
729807
int num() const {
730808
// Note: The signature, `dtype::num` follows the naming of NumPy's public
731809
// Python API (i.e., ``dtype.num``), rather than its internal
732810
// C API (``PyArray_Descr::type_num``).
733811
return detail::array_descriptor_proxy(m_ptr)->type_num;
734812
}
735813

814+
/// Type number of dtype, normalized to match the return value of `num_of` for equivalent
815+
/// types. This function can be used to write switch statements that correctly handle
816+
/// equivalent types with different type numbers.
817+
int normalized_num() const {
818+
int value = num();
819+
if (value >= 0 && value <= detail::npy_api::NPY_VOID_) {
820+
return detail::normalized_dtype_num[value];
821+
}
822+
return value;
823+
}
824+
736825
/// Single character for byteorder
737826
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
738827

tests/test_numpy_dtypes.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
#include "pybind11_tests.h"
1313

14+
#include <cstdint>
15+
#include <stdexcept>
16+
1417
#ifdef __GNUC__
1518
# define PYBIND11_PACKED(cls) cls __attribute__((__packed__))
1619
#else
@@ -297,6 +300,15 @@ py::list test_dtype_ctors() {
297300
return list;
298301
}
299302

303+
template <typename T>
304+
py::array_t<T> dispatch_array_increment(py::array_t<T> arr) {
305+
py::array_t<T> res(arr.shape(0));
306+
for (py::ssize_t i = 0; i < arr.shape(0); ++i) {
307+
res.mutable_at(i) = T(arr.at(i) + 1);
308+
}
309+
return res;
310+
}
311+
300312
struct A {};
301313
struct B {};
302314

@@ -496,6 +508,98 @@ TEST_SUBMODULE(numpy_dtypes, m) {
496508
}
497509
return list;
498510
});
511+
m.def("test_dtype_num_of", []() -> py::list {
    // Each entry pairs the runtime type number (`dtype::of<T>().num()`) with the
    // compile-time one (`dtype::num_of<T>()`) for the same C++ type.
    py::list pairs;
#define TEST_DTYPE(T)                                                                             \
    pairs.append(py::make_tuple(py::dtype::of<T>().num(), py::dtype::num_of<T>()))
    TEST_DTYPE(bool);
    TEST_DTYPE(char);
    TEST_DTYPE(unsigned char);
    TEST_DTYPE(short);
    TEST_DTYPE(unsigned short);
    TEST_DTYPE(int);
    TEST_DTYPE(unsigned int);
    TEST_DTYPE(long);
    TEST_DTYPE(unsigned long);
    TEST_DTYPE(long long);
    TEST_DTYPE(unsigned long long);
    TEST_DTYPE(float);
    TEST_DTYPE(double);
    TEST_DTYPE(long double);
    TEST_DTYPE(std::complex<float>);
    TEST_DTYPE(std::complex<double>);
    TEST_DTYPE(std::complex<long double>);
    TEST_DTYPE(int8_t);
    TEST_DTYPE(uint8_t);
    TEST_DTYPE(int16_t);
    TEST_DTYPE(uint16_t);
    TEST_DTYPE(int32_t);
    TEST_DTYPE(uint32_t);
    TEST_DTYPE(int64_t);
    TEST_DTYPE(uint64_t);
#undef TEST_DTYPE
    return pairs;
});
542+
m.def("test_dtype_normalized_num", []() -> py::list {
    // Each entry pairs the normalized type number of a dtype built from the typenum `NT`
    // with the compile-time type number of the C++ type `T` expected to correspond to it.
    py::list res;
    // The macro deliberately has no trailing semicolon; every use below supplies its own,
    // so no stray empty statements are produced.
#define TEST_DTYPE(NT, T)                                                                         \
    res.append(py::make_tuple(py::dtype(py::detail::npy_api::NT).normalized_num(),                \
                              py::dtype::num_of<T>()))
    TEST_DTYPE(NPY_BOOL_, bool);
    // `signed char` rather than plain `char`: the signedness of `char` is
    // implementation-defined, so only `signed char` is guaranteed to be a signed 8-bit
    // type on every platform.
    TEST_DTYPE(NPY_BYTE_, signed char);
    TEST_DTYPE(NPY_UBYTE_, unsigned char);
    TEST_DTYPE(NPY_SHORT_, short);
    TEST_DTYPE(NPY_USHORT_, unsigned short);
    TEST_DTYPE(NPY_INT_, int);
    TEST_DTYPE(NPY_UINT_, unsigned int);
    TEST_DTYPE(NPY_LONG_, long);
    TEST_DTYPE(NPY_ULONG_, unsigned long);
    TEST_DTYPE(NPY_LONGLONG_, long long);
    TEST_DTYPE(NPY_ULONGLONG_, unsigned long long);
    TEST_DTYPE(NPY_FLOAT_, float);
    TEST_DTYPE(NPY_DOUBLE_, double);
    TEST_DTYPE(NPY_LONGDOUBLE_, long double);
    TEST_DTYPE(NPY_CFLOAT_, std::complex<float>);
    TEST_DTYPE(NPY_CDOUBLE_, std::complex<double>);
    TEST_DTYPE(NPY_CLONGDOUBLE_, std::complex<long double>);
    TEST_DTYPE(NPY_INT8_, int8_t);
    TEST_DTYPE(NPY_UINT8_, uint8_t);
    TEST_DTYPE(NPY_INT16_, int16_t);
    TEST_DTYPE(NPY_UINT16_, uint16_t);
    TEST_DTYPE(NPY_INT32_, int32_t);
    TEST_DTYPE(NPY_UINT32_, uint32_t);
    TEST_DTYPE(NPY_INT64_, int64_t);
    TEST_DTYPE(NPY_UINT64_, uint64_t);
#undef TEST_DTYPE
    return res;
});
575+
m.def("test_dtype_switch", [](const py::array &arr) -> py::array {
    // Dispatch on the normalized type number of the array's dtype; the case labels are the
    // compile-time type numbers of the supported element types.
    const int type_num = arr.dtype().normalized_num();
    switch (type_num) {
        // Signed and unsigned fixed-width integers.
        case py::dtype::num_of<int8_t>():
            return dispatch_array_increment<int8_t>(arr);
        case py::dtype::num_of<uint8_t>():
            return dispatch_array_increment<uint8_t>(arr);
        case py::dtype::num_of<int16_t>():
            return dispatch_array_increment<int16_t>(arr);
        case py::dtype::num_of<uint16_t>():
            return dispatch_array_increment<uint16_t>(arr);
        case py::dtype::num_of<int32_t>():
            return dispatch_array_increment<int32_t>(arr);
        case py::dtype::num_of<uint32_t>():
            return dispatch_array_increment<uint32_t>(arr);
        case py::dtype::num_of<int64_t>():
            return dispatch_array_increment<int64_t>(arr);
        case py::dtype::num_of<uint64_t>():
            return dispatch_array_increment<uint64_t>(arr);
        // Floating-point types.
        case py::dtype::num_of<float>():
            return dispatch_array_increment<float>(arr);
        case py::dtype::num_of<double>():
            return dispatch_array_increment<double>(arr);
        case py::dtype::num_of<long double>():
            return dispatch_array_increment<long double>(arr);
        // Every other type number is rejected.
        default:
            throw std::runtime_error("Unsupported dtype");
    }
});
499603
m.def("test_dtype_methods", []() {
500604
py::list list;
501605
auto dt1 = py::dtype::of<int32_t>();

tests/test_numpy_dtypes.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,28 @@ def test_dtype(simple_dtype):
188188
chr(np.dtype(ch).flags) for ch in expected_chars
189189
]
190190

191+
for a, b in m.test_dtype_num_of():
192+
assert a == b
193+
194+
for a, b in m.test_dtype_normalized_num():
195+
assert a == b
196+
197+
arr = np.array([4, 84, 21, 36])
198+
# Note: "ulong" does not work in NumPy 1.x, so we use "L"
199+
assert (m.test_dtype_switch(arr.astype("byte")) == arr + 1).all()
200+
assert (m.test_dtype_switch(arr.astype("ubyte")) == arr + 1).all()
201+
assert (m.test_dtype_switch(arr.astype("short")) == arr + 1).all()
202+
assert (m.test_dtype_switch(arr.astype("ushort")) == arr + 1).all()
203+
assert (m.test_dtype_switch(arr.astype("intc")) == arr + 1).all()
204+
assert (m.test_dtype_switch(arr.astype("uintc")) == arr + 1).all()
205+
assert (m.test_dtype_switch(arr.astype("long")) == arr + 1).all()
206+
assert (m.test_dtype_switch(arr.astype("L")) == arr + 1).all()
207+
assert (m.test_dtype_switch(arr.astype("longlong")) == arr + 1).all()
208+
assert (m.test_dtype_switch(arr.astype("ulonglong")) == arr + 1).all()
209+
assert (m.test_dtype_switch(arr.astype("single")) == arr + 1).all()
210+
assert (m.test_dtype_switch(arr.astype("double")) == arr + 1).all()
211+
assert (m.test_dtype_switch(arr.astype("longdouble")) == arr + 1).all()
212+
191213

192214
def test_recarray(simple_dtype, packed_dtype):
193215
elements = [(False, 0, 0.0, -0.0), (True, 1, 1.5, -2.5), (False, 2, 3.0, -5.0)]

0 commit comments

Comments
 (0)