|
10 | 10 |
|
11 | 11 | np = pytest.importorskip("numpy") |
12 | 12 |
|
| 13 | +if m.long_double_and_double_have_same_size: |
| 14 | + # Determined by the compiler used to build the pybind11 tests |
| 15 | + # (e.g. MSVC gets here, but MinGW might not). |
| 16 | + np_float128 = None |
| 17 | + np_complex256 = None |
| 18 | +else: |
| 19 | + # Determined by the compiler used to build numpy (e.g. MinGW). |
| 20 | + np_float128 = getattr(np, *["float128"] * 2) |
| 21 | + np_complex256 = getattr(np, *["complex256"] * 2) |
| 22 | + |
| 23 | +CPP_NAME_FORMAT_NP_DTYPE_TABLE = [ |
| 24 | + ("PyObject *", "O", object), |
| 25 | + ("bool", "?", np.bool_), |
| 26 | + ("std::int8_t", "b", np.int8), |
| 27 | + ("std::uint8_t", "B", np.uint8), |
| 28 | + ("std::int16_t", "h", np.int16), |
| 29 | + ("std::uint16_t", "H", np.uint16), |
| 30 | + ("std::int32_t", "i", np.int32), |
| 31 | + ("std::uint32_t", "I", np.uint32), |
| 32 | + ("std::int64_t", "q", np.int64), |
| 33 | + ("std::uint64_t", "Q", np.uint64), |
| 34 | + ("float", "f", np.float32), |
| 35 | + ("double", "d", np.float64), |
| 36 | + ("long double", "g", np_float128), |
| 37 | + ("std::complex<float>", "Zf", np.complex64), |
| 38 | + ("std::complex<double>", "Zd", np.complex128), |
| 39 | + ("std::complex<long double>", "Zg", np_complex256), |
| 40 | +] |
| 41 | +CPP_NAME_FORMAT_TABLE = [ |
| 42 | + (cpp_name, format) |
| 43 | + for cpp_name, format, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE |
| 44 | + if np_dtype is not None |
| 45 | +] |
| 46 | +CPP_NAME_NP_DTYPE_TABLE = [ |
| 47 | + (cpp_name, np_dtype) for cpp_name, _, np_dtype in CPP_NAME_FORMAT_NP_DTYPE_TABLE |
| 48 | +] |
| 49 | + |
| 50 | + |
| 51 | +@pytest.mark.parametrize(("cpp_name", "np_dtype"), CPP_NAME_NP_DTYPE_TABLE) |
| 52 | +def test_format_descriptor_format_buffer_info_equiv(cpp_name, np_dtype): |
| 53 | + if np_dtype is None: |
| 54 | + pytest.skip( |
| 55 | + f"cpp_name=`{cpp_name}`: `long double` and `double` have same size." |
| 56 | + ) |
| 57 | + if isinstance(np_dtype, str): |
| 58 | + pytest.skip(f"np.{np_dtype} does not exist.") |
| 59 | + np_array = np.array([], dtype=np_dtype) |
| 60 | + for other_cpp_name, expected_format in CPP_NAME_FORMAT_TABLE: |
| 61 | + format, np_array_is_matching = m.format_descriptor_format_buffer_info_equiv( |
| 62 | + other_cpp_name, np_array |
| 63 | + ) |
| 64 | + assert format == expected_format |
| 65 | + if other_cpp_name == cpp_name: |
| 66 | + assert np_array_is_matching |
| 67 | + else: |
| 68 | + assert not np_array_is_matching |
| 69 | + |
13 | 70 |
|
14 | 71 | def test_from_python(): |
15 | 72 | with pytest.raises(RuntimeError) as excinfo: |
|
0 commit comments