Skip to content

Commit fc12fa5

Browse files
rongoualliepiper
authored andcommitted
fix transform_inclusive_scan with different value types
1 parent a0948e3 commit fc12fa5

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

testing/transform_scan.cu

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,61 @@ void TestTransformScanSimple(void)
190190
}
191191
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestTransformScanSimple);
192192

193+
struct Record {
194+
int number;
195+
196+
bool operator==(const Record& rhs) const {
197+
return number == rhs.number;
198+
}
199+
bool operator!=(const Record& rhs) const {
200+
return !(rhs == *this);
201+
}
202+
friend Record operator+(Record lhs, const Record& rhs) {
203+
lhs.number += rhs.number;
204+
return lhs;
205+
}
206+
friend std::ostream& operator<<(std::ostream& os, const Record& record) {
207+
os << "number: " << record.number;
208+
return os;
209+
}
210+
};
211+
212+
struct negate {
213+
__host__ __device__ int operator()(Record const& record) const
214+
{
215+
return - record.number;
216+
}
217+
};
218+
219+
void TestTransformInclusiveScanDifferentTypes()
220+
{
221+
typename thrust::host_vector<int>::iterator h_iter;
222+
223+
thrust::host_vector<Record> h_input(5);
224+
thrust::host_vector<int> h_output(5);
225+
thrust::host_vector<int> result(5);
226+
227+
h_input[0] = {1}; h_input[1] = {3}; h_input[2] = {-2}; h_input[3] = {4}; h_input[4] = {-5};
228+
229+
thrust::host_vector<Record> input_copy(h_input);
230+
231+
h_iter = thrust::transform_inclusive_scan(h_input.begin(), h_input.end(), h_output.begin(), negate{}, thrust::plus<int>{});
232+
result[0] = -1; result[1] = -4; result[2] = -2; result[3] = -6; result[4] = -1;
233+
ASSERT_EQUAL(std::size_t(h_iter - h_output.begin()), h_input.size());
234+
ASSERT_EQUAL(h_input, input_copy);
235+
ASSERT_EQUAL(h_output, result);
236+
237+
typename thrust::device_vector<int>::iterator d_iter;
238+
239+
thrust::device_vector<Record> d_input = h_input;
240+
thrust::device_vector<int> d_output(5);
241+
242+
d_iter = thrust::transform_inclusive_scan(d_input.begin(), d_input.end(), d_output.begin(), negate{}, thrust::plus<int>{});
243+
ASSERT_EQUAL(std::size_t(d_iter - d_output.begin()), d_input.size());
244+
ASSERT_EQUAL(d_input, input_copy);
245+
ASSERT_EQUAL(d_output, result);
246+
}
247+
DECLARE_UNITTEST(TestTransformInclusiveScanDifferentTypes);
193248

194249
template <typename T>
195250
struct TestTransformScan

thrust/system/cuda/detail/transform_scan.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ transform_inclusive_scan(execution_policy<Derived> &policy,
5050
TransformOp transform_op,
5151
ScanOp scan_op)
5252
{
53-
// Use the input iterator's value type per https://wg21.link/P0571
54-
using result_type = typename thrust::iterator_value<InputIt>::type;
53+
using input_type = typename thrust::iterator_value<InputIt>::type;
54+
using result_type = typename std::result_of<TransformOp(input_type)>::type;
5555

5656
typedef typename iterator_traits<InputIt>::difference_type size_type;
5757
size_type num_items = static_cast<size_type>(thrust::distance(first, last));

thrust/system/detail/generic/transform_scan.inl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ __host__ __device__
4848
UnaryFunction unary_op,
4949
BinaryFunction binary_op)
5050
{
51-
// Use the input iterator's value type per https://wg21.link/P0571
52-
using ValueType = typename thrust::iterator_value<InputIterator>::type;
51+
using InputType = typename thrust::iterator_value<InputIterator>::type;
52+
using ValueType = typename std::result_of<UnaryFunction(InputType)>::type;
5353

5454
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _first(first, unary_op);
5555
thrust::transform_iterator<UnaryFunction, InputIterator, ValueType> _last(last, unary_op);

0 commit comments

Comments
 (0)