3838#include < datetime.h> // Python datetime builtin.
3939
4040#include < cmath>
41+ #include < complex>
4142#include < cstdint>
4243#include < tuple>
4344#include < type_traits>
@@ -385,18 +386,46 @@ template <>
385386struct type_caster <absl::CivilYear>
386387 : public absl_civil_date_caster<absl::CivilYear> {};
387388
389+ // Using internal namespace to avoid name collisons in case this code is
390+ // accepted upsteam (pybind11).
391+ namespace internal {
392+
393+ template <typename T>
394+ static constexpr bool is_buffer_interface_compatible_type =
395+ std::is_arithmetic<T>::value ||
396+ std::is_same<T, std::complex <float >>::value ||
397+ std::is_same<T, std::complex <double >>::value;
398+
399+ template <typename T, typename SFINAE = void >
400+ struct format_descriptor_char2 {
401+ static constexpr const char c = ' \0 ' ;
402+ };
403+
404+ template <typename T>
405+ struct format_descriptor_char2 <std::complex <T>> : format_descriptor<T> {};
406+
407+ template <typename T>
408+ inline bool buffer_view_matches_format_descriptor (const char * view_format) {
409+ return view_format[0 ] == format_descriptor<T>::c ||
410+ (view_format[0 ] == ' Z' &&
411+ view_format[1 ] == format_descriptor_char2<T>::c);
412+ }
413+
414+ } // namespace internal
415+
388416// Returns {true, a span referencing the data contained by src} without copying
389417// or converting the data if possible. Otherwise returns {false, an empty span}.
390- template <typename T, typename std::enable_if<std::is_arithmetic<T>::value,
391- bool >::type = true >
418+ template <typename T, typename std::enable_if<
419+ internal::is_buffer_interface_compatible_type<T>,
420+ bool >::type = true >
392421std::tuple<bool , absl::Span<T>> LoadSpanFromBuffer (handle src) {
393422 Py_buffer view;
394423 int flags = PyBUF_STRIDES | PyBUF_FORMAT;
395424 if (!std::is_const<T>::value) flags |= PyBUF_WRITABLE;
396425 if (PyObject_GetBuffer (src.ptr (), &view, flags) == 0 ) {
397426 auto cleanup = absl::MakeCleanup ([&view] { PyBuffer_Release (&view); });
398427 if (view.ndim == 1 && view.strides [0 ] == sizeof (T) &&
399- view. format [ 0 ] == format_descriptor <T>::c ) {
428+ internal::buffer_view_matches_format_descriptor <T>(view. format ) ) {
400429 return {true , absl::MakeSpan (static_cast <T*>(view.buf ), view.shape [0 ])};
401430 }
402431 } else {
@@ -405,9 +434,9 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
405434 }
406435 return {false , absl::Span<T>()};
407436}
408- // If T is not a numeric type, the buffer interface cannot be used.
409- template < typename T, typename std::enable_if<!std::is_arithmetic <T>::value ,
410- bool >::type = true >
437+ template < typename T, typename std::enable_if<
438+ !internal::is_buffer_interface_compatible_type <T>,
439+ bool >::type = true >
411440constexpr std::tuple<bool , absl::Span<T>> LoadSpanFromBuffer (handle src) {
412441 return {false , absl::Span<T>()};
413442}
0 commit comments