Skip to content

Commit 3af276e

Browse files
committed
Make thrust::/std::complex interop __device__ qualified for C++11+.
The functions to construct, assign and compare thrust::complex values from and with std::complex values were marked __host__ since forever, because access to std::complex is performed using member functions. However, in C++11, an explicit permission has been given to reinterpret_cast std::complex values as arrays of two elements of its template parameter, allowing us to implement a __device__-compatible set of those interop functions, when compiling for C++11. For C++03, they are still only __host__-qualified. Bug 2502854
1 parent 63d847b commit 3af276e

File tree

3 files changed

+123
-79
lines changed

3 files changed

+123
-79
lines changed

testing/complex.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,27 @@ struct TestComplexStreamOperators
284284
};
285285

286286
SimpleUnitTest<TestComplexStreamOperators, FloatingPointTypes> TestComplexStreamOperatorsInstance;
287+
288+
#if THRUST_CPP_DIALECT >= 2011
289+
template<typename T>
290+
struct TestComplexStdComplexDeviceInterop
291+
{
292+
void operator()()
293+
{
294+
thrust::host_vector<T> data = unittest::random_samples<T>(6);
295+
std::vector<std::complex<T> > vec(10);
296+
vec[0] = std::complex<T>(data[0], data[1]);
297+
vec[1] = std::complex<T>(data[2], data[3]);
298+
vec[2] = std::complex<T>(data[4], data[5]);
299+
300+
thrust::device_vector<thrust::complex<T> > device_vec = vec;
301+
ASSERT_ALMOST_EQUAL(vec[0].real(), thrust::complex<T>(device_vec[0]).real());
302+
ASSERT_ALMOST_EQUAL(vec[0].imag(), thrust::complex<T>(device_vec[0]).imag());
303+
ASSERT_ALMOST_EQUAL(vec[1].real(), thrust::complex<T>(device_vec[1]).real());
304+
ASSERT_ALMOST_EQUAL(vec[1].imag(), thrust::complex<T>(device_vec[1]).imag());
305+
ASSERT_ALMOST_EQUAL(vec[2].real(), thrust::complex<T>(device_vec[2]).real());
306+
ASSERT_ALMOST_EQUAL(vec[2].imag(), thrust::complex<T>(device_vec[2]).imag());
307+
}
308+
};
309+
SimpleUnitTest<TestComplexStdComplexDeviceInterop, FloatingPointTypes> TestComplexStdComplexDeviceInteropInstance;
310+
#endif

thrust/complex.h

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2008-2018 NVIDIA Corporation
2+
* Copyright 2008-2019 NVIDIA Corporation
33
* Copyright 2013 Filipe RNC Maia
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,11 +28,27 @@
2828
#include <sstream>
2929
#include <thrust/detail/type_traits.h>
3030

31+
#if THRUST_CPP_DIALECT >= 2011
32+
# define THRUST_STD_COMPLEX_REAL(z) \
33+
reinterpret_cast< \
34+
const typename thrust::detail::remove_reference<decltype(z)>::type::value_type (&)[2] \
35+
>(z)[0]
36+
# define THRUST_STD_COMPLEX_IMAG(z) \
37+
reinterpret_cast< \
38+
const typename thrust::detail::remove_reference<decltype(z)>::type::value_type (&)[2] \
39+
>(z)[1]
40+
# define THRUST_STD_COMPLEX_DEVICE __device__
41+
#else
42+
# define THRUST_STD_COMPLEX_REAL(z) (z).real()
43+
# define THRUST_STD_COMPLEX_IMAG(z) (z).imag()
44+
# define THRUST_STD_COMPLEX_DEVICE
45+
#endif
46+
3147
namespace thrust
3248
{
3349

3450
/*
35-
* Calls to the standard math library from inside the thrust namespace
51+
* Calls to the standard math library from inside the thrust namespace
3652
* with real arguments require explicit scope otherwise they will fail
3753
* to resolve as it will find the equivalent complex function but then
3854
* fail to match the template, and give up looking for other scopes.
@@ -125,7 +141,7 @@ struct complex
125141
*
126142
* \param z The \p complex to copy from.
127143
*/
128-
__host__
144+
__host__ THRUST_STD_COMPLEX_DEVICE
129145
complex(const std::complex<T>& z);
130146

131147
/*! This converting copy constructor copies from a <tt>std::complex</tt> with
@@ -136,7 +152,7 @@ struct complex
136152
* \tparam U is convertible to \c value_type.
137153
*/
138154
template <typename U>
139-
__host__
155+
__host__ THRUST_STD_COMPLEX_DEVICE
140156
complex(const std::complex<U>& z);
141157

142158

@@ -184,7 +200,7 @@ struct complex
184200
*
185201
* \param z The \p complex to copy from.
186202
*/
187-
__host__
203+
__host__ THRUST_STD_COMPLEX_DEVICE
188204
complex& operator=(const std::complex<T>& z);
189205

190206
/*! Assign `z.real()` and `z.imag()` to the real and imaginary parts of this
@@ -195,7 +211,7 @@ struct complex
195211
* \tparam U is convertible to \c value_type.
196212
*/
197213
template <typename U>
198-
__host__
214+
__host__ THRUST_STD_COMPLEX_DEVICE
199215
complex& operator=(const std::complex<U>& z);
200216

201217

@@ -205,7 +221,7 @@ struct complex
205221
* \p complex.
206222
*
207223
* \param z The \p complex to be added.
208-
*
224+
*
209225
* \tparam U is convertible to \c value_type.
210226
*/
211227
template <typename U>
@@ -269,7 +285,7 @@ struct complex
269285

270286
/*! Multiplies this \p complex by a scalar and assigns the result
271287
* to this \p complex.
272-
*
288+
*
273289
* \param z The scalar to be multiplied.
274290
*
275291
* \tparam U is convertible to \c value_type.
@@ -280,7 +296,7 @@ struct complex
280296

281297
/*! Divides this \p complex by a scalar and assigns the result to
282298
* this \p complex.
283-
*
299+
*
284300
* \param z The scalar to be divided.
285301
*
286302
* \tparam U is convertible to \c value_type.
@@ -291,7 +307,7 @@ struct complex
291307

292308

293309

294-
/* --- Getter functions ---
310+
/* --- Getter functions ---
295311
* The volatile ones are there to help for example
296312
* with certain reductions optimizations
297313
*/
@@ -318,7 +334,7 @@ struct complex
318334

319335

320336

321-
/* --- Setter functions ---
337+
/* --- Setter functions ---
322338
* The volatile ones are there to help for example
323339
* with certain reductions optimizations
324340
*/
@@ -434,8 +450,8 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
434450
polar(const T0& m, const T1& theta = T1());
435451

436452
/*! Returns the projection of a \p complex on the Riemann sphere.
437-
* For all finite \p complex it returns the argument. For \p complexs
438-
* with a non finite part returns (INFINITY,+/-0) where the sign of
453+
* For all finite \p complex it returns the argument. For \p complexs
454+
* with a non finite part returns (INFINITY,+/-0) where the sign of
439455
* the zero matches the sign of the imaginary part of the argument.
440456
*
441457
* \param z The \p complex argument.
@@ -449,7 +465,7 @@ complex<T> proj(const T& z);
449465
/* --- Binary Arithmetic operators --- */
450466

451467
/*! Adds two \p complex numbers.
452-
*
468+
*
453469
* The value types of the two \p complex types should be compatible and the
454470
* type of the returned \p complex is the promoted type of the two arguments.
455471
*
@@ -462,7 +478,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
462478
operator+(const complex<T0>& x, const complex<T1>& y);
463479

464480
/*! Adds a scalar to a \p complex number.
465-
*
481+
*
466482
* The value type of the \p complex should be compatible with the scalar and
467483
* the type of the returned \p complex is the promoted type of the two arguments.
468484
*
@@ -475,7 +491,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
475491
operator+(const complex<T0>& x, const T1& y);
476492

477493
/*! Adds a \p complex number to a scalar.
478-
*
494+
*
479495
* The value type of the \p complex should be compatible with the scalar and
480496
* the type of the returned \p complex is the promoted type of the two arguments.
481497
*
@@ -488,7 +504,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
488504
operator+(const T0& x, const complex<T1>& y);
489505

490506
/*! Subtracts two \p complex numbers.
491-
*
507+
*
492508
* The value types of the two \p complex types should be compatible and the
493509
* type of the returned \p complex is the promoted type of the two arguments.
494510
*
@@ -501,7 +517,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
501517
operator-(const complex<T0>& x, const complex<T1>& y);
502518

503519
/*! Subtracts a scalar from a \p complex number.
504-
*
520+
*
505521
* The value type of the \p complex should be compatible with the scalar and
506522
* the type of the returned \p complex is the promoted type of the two arguments.
507523
*
@@ -514,7 +530,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
514530
operator-(const complex<T0>& x, const T1& y);
515531

516532
/*! Subtracts a \p complex number from a scalar.
517-
*
533+
*
518534
* The value type of the \p complex should be compatible with the scalar and
519535
* the type of the returned \p complex is the promoted type of the two arguments.
520536
*
@@ -527,7 +543,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
527543
operator-(const T0& x, const complex<T1>& y);
528544

529545
/*! Multiplies two \p complex numbers.
530-
*
546+
*
531547
* The value types of the two \p complex types should be compatible and the
532548
* type of the returned \p complex is the promoted type of the two arguments.
533549
*
@@ -550,7 +566,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
550566
operator*(const complex<T0>& x, const T1& y);
551567

552568
/*! Multiplies a scalar by a \p complex number.
553-
*
569+
*
554570
* The value type of the \p complex should be compatible with the scalar and
555571
* the type of the returned \p complex is the promoted type of the two arguments.
556572
*
@@ -563,7 +579,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
563579
operator*(const T0& x, const complex<T1>& y);
564580

565581
/*! Divides two \p complex numbers.
566-
*
582+
*
567583
* The value types of the two \p complex types should be compatible and the
568584
* type of the returned \p complex is the promoted type of the two arguments.
569585
*
@@ -576,7 +592,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
576592
operator/(const complex<T0>& x, const complex<T1>& y);
577593

578594
/*! Divides a \p complex number by a scalar.
579-
*
595+
*
580596
* The value type of the \p complex should be compatible with the scalar and
581597
* the type of the returned \p complex is the promoted type of the two arguments.
582598
*
@@ -589,7 +605,7 @@ complex<typename detail::promoted_numerical_type<T0, T1>::type>
589605
operator/(const complex<T0>& x, const T1& y);
590606

591607
/*! Divides a scalar by a \p complex number.
592-
*
608+
*
593609
* The value type of the \p complex should be compatible with the scalar and
594610
* the type of the returned \p complex is the promoted type of the two arguments.
595611
*
@@ -657,7 +673,7 @@ complex<T> log10(const complex<T>& z);
657673
/* --- Power Functions --- */
658674

659675
/*! Returns a \p complex number raised to another.
660-
*
676+
*
661677
* The value types of the two \p complex types should be compatible and the
662678
* type of the returned \p complex is the promoted type of the two arguments.
663679
*
@@ -764,7 +780,7 @@ complex<T> tanh(const complex<T>& z);
764780

765781
/*! Returns the complex arc cosine of a \p complex number.
766782
*
767-
* The range of the real part of the result is [0, Pi] and
783+
* The range of the real part of the result is [0, Pi] and
768784
* the range of the imaginary part is [-inf, +inf]
769785
*
770786
* \param z The \p complex argument.
@@ -775,7 +791,7 @@ complex<T> acos(const complex<T>& z);
775791

776792
/*! Returns the complex arc sine of a \p complex number.
777793
*
778-
* The range of the real part of the result is [-Pi/2, Pi/2] and
794+
* The range of the real part of the result is [-Pi/2, Pi/2] and
779795
* the range of the imaginary part is [-inf, +inf]
780796
*
781797
* \param z The \p complex argument.
@@ -786,7 +802,7 @@ complex<T> asin(const complex<T>& z);
786802

787803
/*! Returns the complex arc tangent of a \p complex number.
788804
*
789-
* The range of the real part of the result is [-Pi/2, Pi/2] and
805+
* The range of the real part of the result is [-Pi/2, Pi/2] and
790806
* the range of the imaginary part is [-inf, +inf]
791807
*
792808
* \param z The \p complex argument.
@@ -801,7 +817,7 @@ complex<T> atan(const complex<T>& z);
801817

802818
/*! Returns the complex inverse hyperbolic cosine of a \p complex number.
803819
*
804-
* The range of the real part of the result is [0, +inf] and
820+
* The range of the real part of the result is [0, +inf] and
805821
* the range of the imaginary part is [-Pi, Pi]
806822
*
807823
* \param z The \p complex argument.
@@ -812,7 +828,7 @@ complex<T> acosh(const complex<T>& z);
812828

813829
/*! Returns the complex inverse hyperbolic sine of a \p complex number.
814830
*
815-
* The range of the real part of the result is [-inf, +inf] and
831+
* The range of the real part of the result is [-inf, +inf] and
816832
* the range of the imaginary part is [-Pi/2, Pi/2]
817833
*
818834
* \param z The \p complex argument.
@@ -823,7 +839,7 @@ complex<T> asinh(const complex<T>& z);
823839

824840
/*! Returns the complex inverse hyperbolic tangent of a \p complex number.
825841
*
826-
* The range of the real part of the result is [-inf, +inf] and
842+
* The range of the real part of the result is [-inf, +inf] and
827843
* the range of the imaginary part is [-Pi/2, Pi/2]
828844
*
829845
* \param z The \p complex argument.
@@ -852,7 +868,7 @@ operator<<(std::basic_ostream<CharT, Traits>& os, const complex<T>& z);
852868
* - (real)
853869
* - (real, imaginary)
854870
*
855-
* The values read must be convertible to the \p complex's \c value_type
871+
* The values read must be convertible to the \p complex's \c value_type
856872
*
857873
* \param is The input stream.
858874
* \param z The \p complex number to set.
@@ -881,7 +897,7 @@ bool operator==(const complex<T0>& x, const complex<T1>& y);
881897
* \param y The second \p complex.
882898
*/
883899
template <typename T0, typename T1>
884-
__host__
900+
__host__ THRUST_STD_COMPLEX_DEVICE
885901
bool operator==(const complex<T0>& x, const std::complex<T1>& y);
886902

887903
/*! Returns true if two \p complex numbers are equal and false otherwise.
@@ -890,7 +906,7 @@ bool operator==(const complex<T0>& x, const std::complex<T1>& y);
890906
* \param y The second \p complex.
891907
*/
892908
template <typename T0, typename T1>
893-
__host__
909+
__host__ THRUST_STD_COMPLEX_DEVICE
894910
bool operator==(const std::complex<T0>& x, const complex<T1>& y);
895911

896912
/*! Returns true if the imaginary part of the \p complex number is zero and
@@ -928,7 +944,7 @@ bool operator!=(const complex<T0>& x, const complex<T1>& y);
928944
* \param y The second \p complex.
929945
*/
930946
template <typename T0, typename T1>
931-
__host__
947+
__host__ THRUST_STD_COMPLEX_DEVICE
932948
bool operator!=(const complex<T0>& x, const std::complex<T1>& y);
933949

934950
/*! Returns true if two \p complex numbers are different and false otherwise.
@@ -937,7 +953,7 @@ bool operator!=(const complex<T0>& x, const std::complex<T1>& y);
937953
* \param y The second \p complex.
938954
*/
939955
template <typename T0, typename T1>
940-
__host__
956+
__host__ THRUST_STD_COMPLEX_DEVICE
941957
bool operator!=(const std::complex<T0>& x, const complex<T1>& y);
942958

943959
/*! Returns true if the imaginary part of the \p complex number is not zero or
@@ -964,6 +980,10 @@ bool operator!=(const complex<T0>& x, const T1& y);
964980

965981
#include <thrust/detail/complex/complex.inl>
966982

983+
#undef THRUST_STD_COMPLEX_REAL
984+
#undef THRUST_STD_COMPLEX_IMAG
985+
#undef THRUST_STD_COMPLEX_DEVICE
986+
967987
/*! \} // complex_numbers
968988
*/
969989

0 commit comments

Comments
 (0)