@@ -190,6 +190,61 @@ void TestTransformScanSimple(void)
190190}
191191DECLARE_INTEGRAL_VECTOR_UNITTEST (TestTransformScanSimple);
192192
193+ struct Record {
194+ int number;
195+
196+ bool operator ==(const Record& rhs) const {
197+ return number == rhs.number ;
198+ }
199+ bool operator !=(const Record& rhs) const {
200+ return !(rhs == *this );
201+ }
202+ friend Record operator +(Record lhs, const Record& rhs) {
203+ lhs.number += rhs.number ;
204+ return lhs;
205+ }
206+ friend std::ostream& operator <<(std::ostream& os, const Record& record) {
207+ os << " number: " << record.number ;
208+ return os;
209+ }
210+ };
211+
212+ struct negate {
213+ __host__ __device__ int operator ()(Record const & record) const
214+ {
215+ return - record.number ;
216+ }
217+ };
218+
219+ void TestTransformInclusiveScanDifferentTypes ()
220+ {
221+ typename thrust::host_vector<int >::iterator h_iter;
222+
223+ thrust::host_vector<Record> h_input (5 );
224+ thrust::host_vector<int > h_output (5 );
225+ thrust::host_vector<int > result (5 );
226+
227+ h_input[0 ] = {1 }; h_input[1 ] = {3 }; h_input[2 ] = {-2 }; h_input[3 ] = {4 }; h_input[4 ] = {-5 };
228+
229+ thrust::host_vector<Record> input_copy (h_input);
230+
231+ h_iter = thrust::transform_inclusive_scan (h_input.begin (), h_input.end (), h_output.begin (), negate{}, thrust::plus<int >{});
232+ result[0 ] = -1 ; result[1 ] = -4 ; result[2 ] = -2 ; result[3 ] = -6 ; result[4 ] = -1 ;
233+ ASSERT_EQUAL (std::size_t (h_iter - h_output.begin ()), h_input.size ());
234+ ASSERT_EQUAL (h_input, input_copy);
235+ ASSERT_EQUAL (h_output, result);
236+
237+ typename thrust::device_vector<int >::iterator d_iter;
238+
239+ thrust::device_vector<Record> d_input = h_input;
240+ thrust::device_vector<int > d_output (5 );
241+
242+ d_iter = thrust::transform_inclusive_scan (d_input.begin (), d_input.end (), d_output.begin (), negate{}, thrust::plus<int >{});
243+ ASSERT_EQUAL (std::size_t (d_iter - d_output.begin ()), d_input.size ());
244+ ASSERT_EQUAL (d_input, input_copy);
245+ ASSERT_EQUAL (d_output, result);
246+ }
247+ DECLARE_UNITTEST (TestTransformInclusiveScanDifferentTypes);
193248
194249template <typename T>
195250struct TestTransformScan
0 commit comments