|
11 | 11 |
|
12 | 12 | #include "pybind11_tests.h" |
13 | 13 |
|
| 14 | +#include <cstdint> |
| 15 | +#include <stdexcept> |
| 16 | + |
14 | 17 | #ifdef __GNUC__ |
15 | 18 | # define PYBIND11_PACKED(cls) cls __attribute__((__packed__)) |
16 | 19 | #else |
@@ -297,6 +300,15 @@ py::list test_dtype_ctors() { |
297 | 300 | return list; |
298 | 301 | } |
299 | 302 |
|
| 303 | +template <typename T> |
| 304 | +py::array_t<T> dispatch_array_increment(py::array_t<T> arr) { |
| 305 | + py::array_t<T> res(arr.shape(0)); |
| 306 | + for (py::ssize_t i = 0; i < arr.shape(0); ++i) { |
| 307 | + res.mutable_at(i) = T(arr.at(i) + 1); |
| 308 | + } |
| 309 | + return res; |
| 310 | +} |
| 311 | + |
300 | 312 | struct A {}; |
301 | 313 | struct B {}; |
302 | 314 |
|
@@ -496,6 +508,98 @@ TEST_SUBMODULE(numpy_dtypes, m) { |
496 | 508 | } |
497 | 509 | return list; |
498 | 510 | }); |
| 511 | + m.def("test_dtype_num_of", []() -> py::list { |
| 512 | + py::list res; |
| 513 | +#define TEST_DTYPE(T) res.append(py::make_tuple(py::dtype::of<T>().num(), py::dtype::num_of<T>())); |
| 514 | + TEST_DTYPE(bool) |
| 515 | + TEST_DTYPE(char) |
| 516 | + TEST_DTYPE(unsigned char) |
| 517 | + TEST_DTYPE(short) |
| 518 | + TEST_DTYPE(unsigned short) |
| 519 | + TEST_DTYPE(int) |
| 520 | + TEST_DTYPE(unsigned int) |
| 521 | + TEST_DTYPE(long) |
| 522 | + TEST_DTYPE(unsigned long) |
| 523 | + TEST_DTYPE(long long) |
| 524 | + TEST_DTYPE(unsigned long long) |
| 525 | + TEST_DTYPE(float) |
| 526 | + TEST_DTYPE(double) |
| 527 | + TEST_DTYPE(long double) |
| 528 | + TEST_DTYPE(std::complex<float>) |
| 529 | + TEST_DTYPE(std::complex<double>) |
| 530 | + TEST_DTYPE(std::complex<long double>) |
| 531 | + TEST_DTYPE(int8_t) |
| 532 | + TEST_DTYPE(uint8_t) |
| 533 | + TEST_DTYPE(int16_t) |
| 534 | + TEST_DTYPE(uint16_t) |
| 535 | + TEST_DTYPE(int32_t) |
| 536 | + TEST_DTYPE(uint32_t) |
| 537 | + TEST_DTYPE(int64_t) |
| 538 | + TEST_DTYPE(uint64_t) |
| 539 | +#undef TEST_DTYPE |
| 540 | + return res; |
| 541 | + }); |
| 542 | + m.def("test_dtype_normalized_num", []() -> py::list { |
| 543 | + py::list res; |
| 544 | +#define TEST_DTYPE(NT, T) \ |
| 545 | + res.append(py::make_tuple(py::dtype(py::detail::npy_api::NT).normalized_num(), \ |
| 546 | + py::dtype::num_of<T>())); |
| 547 | + TEST_DTYPE(NPY_BOOL_, bool) |
| 548 | + TEST_DTYPE(NPY_BYTE_, char); |
| 549 | + TEST_DTYPE(NPY_UBYTE_, unsigned char); |
| 550 | + TEST_DTYPE(NPY_SHORT_, short); |
| 551 | + TEST_DTYPE(NPY_USHORT_, unsigned short); |
| 552 | + TEST_DTYPE(NPY_INT_, int); |
| 553 | + TEST_DTYPE(NPY_UINT_, unsigned int); |
| 554 | + TEST_DTYPE(NPY_LONG_, long); |
| 555 | + TEST_DTYPE(NPY_ULONG_, unsigned long); |
| 556 | + TEST_DTYPE(NPY_LONGLONG_, long long); |
| 557 | + TEST_DTYPE(NPY_ULONGLONG_, unsigned long long); |
| 558 | + TEST_DTYPE(NPY_FLOAT_, float); |
| 559 | + TEST_DTYPE(NPY_DOUBLE_, double); |
| 560 | + TEST_DTYPE(NPY_LONGDOUBLE_, long double); |
| 561 | + TEST_DTYPE(NPY_CFLOAT_, std::complex<float>); |
| 562 | + TEST_DTYPE(NPY_CDOUBLE_, std::complex<double>); |
| 563 | + TEST_DTYPE(NPY_CLONGDOUBLE_, std::complex<long double>); |
| 564 | + TEST_DTYPE(NPY_INT8_, int8_t); |
| 565 | + TEST_DTYPE(NPY_UINT8_, uint8_t); |
| 566 | + TEST_DTYPE(NPY_INT16_, int16_t); |
| 567 | + TEST_DTYPE(NPY_UINT16_, uint16_t); |
| 568 | + TEST_DTYPE(NPY_INT32_, int32_t); |
| 569 | + TEST_DTYPE(NPY_UINT32_, uint32_t); |
| 570 | + TEST_DTYPE(NPY_INT64_, int64_t); |
| 571 | + TEST_DTYPE(NPY_UINT64_, uint64_t); |
| 572 | +#undef TEST_DTYPE |
| 573 | + return res; |
| 574 | + }); |
| 575 | + m.def("test_dtype_switch", [](const py::array &arr) -> py::array { |
| 576 | + switch (arr.dtype().normalized_num()) { |
| 577 | + case py::dtype::num_of<int8_t>(): |
| 578 | + return dispatch_array_increment<int8_t>(arr); |
| 579 | + case py::dtype::num_of<uint8_t>(): |
| 580 | + return dispatch_array_increment<uint8_t>(arr); |
| 581 | + case py::dtype::num_of<int16_t>(): |
| 582 | + return dispatch_array_increment<int16_t>(arr); |
| 583 | + case py::dtype::num_of<uint16_t>(): |
| 584 | + return dispatch_array_increment<uint16_t>(arr); |
| 585 | + case py::dtype::num_of<int32_t>(): |
| 586 | + return dispatch_array_increment<int32_t>(arr); |
| 587 | + case py::dtype::num_of<uint32_t>(): |
| 588 | + return dispatch_array_increment<uint32_t>(arr); |
| 589 | + case py::dtype::num_of<int64_t>(): |
| 590 | + return dispatch_array_increment<int64_t>(arr); |
| 591 | + case py::dtype::num_of<uint64_t>(): |
| 592 | + return dispatch_array_increment<uint64_t>(arr); |
| 593 | + case py::dtype::num_of<float>(): |
| 594 | + return dispatch_array_increment<float>(arr); |
| 595 | + case py::dtype::num_of<double>(): |
| 596 | + return dispatch_array_increment<double>(arr); |
| 597 | + case py::dtype::num_of<long double>(): |
| 598 | + return dispatch_array_increment<long double>(arr); |
| 599 | + default: |
| 600 | + throw std::runtime_error("Unsupported dtype"); |
| 601 | + } |
| 602 | + }); |
499 | 603 | m.def("test_dtype_methods", []() { |
500 | 604 | py::list list; |
501 | 605 | auto dt1 = py::dtype::of<int32_t>(); |
|
0 commit comments