Skip to content

Commit 1d16811

Browse files
Francis Lemairegriwes
authored and committed
The great Thrust index type fix, part 9: exclusive_scan, inclusive_scan.
1 parent 01bbe09 commit 1d16811

File tree

2 files changed

+122
-20
lines changed

2 files changed

+122
-20
lines changed

testing/scan.cu

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include <thrust/functional.h>
44
#include <thrust/iterator/discard_iterator.h>
55
#include <thrust/iterator/retag.h>
6+
#include <thrust/device_malloc.h>
7+
#include <thrust/device_free.h>
68

79

810
template<typename T>
@@ -555,3 +557,93 @@ void TestInclusiveScanWithIndirection(void)
555557
}
556558
DECLARE_INTEGRAL_VECTOR_UNITTEST(TestInclusiveScanWithIndirection);
557559

560+
// Stateless output-iterator stand-in used by the big-index scan tests.
//
// Incrementing, dereferencing, offsetting and subscripting all hand back a
// copy of the object itself, so whatever position an algorithm writes to,
// the write ends up in operator=(long long) below. That operator records
// whether the value `expected` was ever assigned.
//
// The struct is kept an aggregate (no constructors) so that it can be
// brace-initialised as { expected, flag }.
struct only_set_when_expected_it
{
    long long expected; // the one assigned value that raises *flag
    bool * flag;        // dereferenced and written from device code in operator=

    // Dereference and subscript: the "reference" is the iterator itself.
    __host__ __device__ only_set_when_expected_it operator*() const { return *this; }

    template<typename Index>
    __host__ __device__ only_set_when_expected_it operator[](Index) const { return *this; }

    // Advancing is a no-op: every position behaves identically.
    __host__ __device__ only_set_when_expected_it operator++() const { return *this; }

    template<typename Difference>
    __host__ __device__ only_set_when_expected_it operator+(Difference) const { return *this; }

    // Sink for values written through the iterator. Only a value equal to
    // `expected` has any effect; everything else is dropped.
    __device__
    void operator=(long long value) const
    {
        if (value != expected)
        {
            return;
        }

        *flag = true;
    }
};
581+
582+
namespace thrust
{
// Explicit traits for only_set_when_expected_it, which declares no nested
// iterator typedefs of its own. Only value_type and reference are provided.
template<>
struct iterator_traits<only_set_when_expected_it>
{
    // Dereferencing the iterator yields a copy of the iterator itself.
    typedef only_set_when_expected_it reference;

    // Values assigned through the iterator are long long.
    typedef long long value_type;
};
} // namespace thrust
591+
592+
// Runs thrust::inclusive_scan over a sequence of 2^magnitude ones and checks
// that the value 2^magnitude was written to the output at least once.
//
// For an inclusive running sum of ones, 2^magnitude is the value at the last
// position, so the check fails if the scan does not reach the end of the input.
//
// magnitude: base-2 logarithm of the number of input elements.
void TestInclusiveScanWithBigIndexesHelper(int magnitude)
{
    const long long num_items = 1ll << magnitude;

    // Input: the constant 1 repeated num_items times; nothing is materialised.
    thrust::constant_iterator<long long> first(1);
    thrust::constant_iterator<long long> last = first + num_items;
    ASSERT_EQUAL(thrust::distance(first, last), num_items);

    // One bool allocated with device_malloc; the output iterator sets it
    // from device code when the expected value is assigned.
    thrust::device_ptr<bool> seen_flag = thrust::device_malloc<bool>(1);
    *seen_flag = false;

    only_set_when_expected_it sink = { num_items, thrust::raw_pointer_cast(seen_flag) };

    thrust::inclusive_scan(thrust::device, first, last, sink);

    // Read the flag back and free the allocation before asserting, so the
    // free has already happened by the time the assertion is evaluated.
    const bool seen_on_host = *seen_flag;
    thrust::device_free(seen_flag);

    ASSERT_EQUAL(seen_on_host, true);
}
610+
611+
// Exercises inclusive_scan with 2^30, 2^31, 2^32 and 2^33 input elements,
// i.e. element counts on both sides of what 32-bit signed and unsigned
// integers can represent.
void TestInclusiveScanWithBigIndexes()
{
    for (int magnitude = 30; magnitude <= 33; ++magnitude)
    {
        TestInclusiveScanWithBigIndexesHelper(magnitude);
    }
}

DECLARE_UNITTEST(TestInclusiveScanWithBigIndexes);
620+
621+
// Runs thrust::exclusive_scan (initial value 0) over a sequence of
// 2^magnitude ones and checks that the value 2^magnitude - 1 was written to
// the output at least once.
//
// For an exclusive running sum of ones starting from 0, 2^magnitude - 1 is the
// value at the last position, so the check fails if the scan does not reach
// the end of the input.
//
// magnitude: base-2 logarithm of the number of input elements.
void TestExclusiveScanWithBigIndexesHelper(int magnitude)
{
    const long long num_items = 1ll << magnitude;

    // Input: the constant 1 repeated num_items times; nothing is materialised.
    thrust::constant_iterator<long long> first(1);
    thrust::constant_iterator<long long> last = first + num_items;
    ASSERT_EQUAL(thrust::distance(first, last), num_items);

    // One bool allocated with device_malloc; the output iterator sets it
    // from device code when the expected value is assigned.
    thrust::device_ptr<bool> seen_flag = thrust::device_malloc<bool>(1);
    *seen_flag = false;

    only_set_when_expected_it sink = { num_items - 1, thrust::raw_pointer_cast(seen_flag) };

    thrust::exclusive_scan(thrust::device, first, last, sink, 0ll);

    // Read the flag back and free the allocation before asserting, so the
    // free has already happened by the time the assertion is evaluated.
    const bool seen_on_host = *seen_flag;
    thrust::device_free(seen_flag);

    ASSERT_EQUAL(seen_on_host, true);
}
639+
640+
// Exercises exclusive_scan with 2^30, 2^31, 2^32 and 2^33 input elements,
// i.e. element counts on both sides of what 32-bit signed and unsigned
// integers can represent.
void TestExclusiveScanWithBigIndexes()
{
    for (int magnitude = 30; magnitude <= 33; ++magnitude)
    {
        TestExclusiveScanWithBigIndexesHelper(magnitude);
    }
}

DECLARE_UNITTEST(TestExclusiveScanWithBigIndexes);
649+

thrust/system/cuda/detail/scan.h

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,11 @@
4040
#include <cub/device/device_scan.cuh>
4141
#include <thrust/system/cuda/detail/core/agent_launcher.h>
4242
#include <thrust/system/cuda/detail/par_to_seq.h>
43+
#include <thrust/system/cuda/detail/dispatch.h>
4344
#include <thrust/detail/mpl/math.h>
4445
#include <thrust/detail/minmax.h>
4546
#include <thrust/distance.h>
47+
#include <thrust/iterator/iterator_traits.h>
4648

4749
THRUST_BEGIN_NS
4850
template <typename DerivedPolicy,
@@ -710,31 +712,37 @@ namespace __scan {
710712
bool debug_sync = THRUST_DEBUG_SYNC_FLAG;
711713

712714
cudaError_t status;
713-
status = doit_step<Inclusive>(NULL,
714-
storage_size,
715-
input_it,
716-
num_items,
717-
add_init_to_exclusive_scan,
718-
output_it,
719-
scan_op,
720-
stream,
721-
debug_sync);
715+
THRUST_INDEX_TYPE_DISPATCH(status,
716+
doit_step<Inclusive>,
717+
num_items,
718+
(NULL,
719+
storage_size,
720+
input_it,
721+
num_items_fixed,
722+
add_init_to_exclusive_scan,
723+
output_it,
724+
scan_op,
725+
stream,
726+
debug_sync));
722727
cuda_cub::throw_on_error(status, "scan failed on 1st step");
723728

724729
// Allocate temporary storage.
725730
thrust::detail::temporary_array<thrust::detail::uint8_t, Derived>
726731
tmp(policy, storage_size);
727732
void *ptr = static_cast<void*>(tmp.data().get());
728733

729-
status = doit_step<Inclusive>(ptr,
730-
storage_size,
731-
input_it,
732-
num_items,
733-
add_init_to_exclusive_scan,
734-
output_it,
735-
scan_op,
736-
stream,
737-
debug_sync);
734+
THRUST_INDEX_TYPE_DISPATCH(status,
735+
doit_step<Inclusive>,
736+
num_items,
737+
(ptr,
738+
storage_size,
739+
input_it,
740+
num_items_fixed,
741+
add_init_to_exclusive_scan,
742+
output_it,
743+
scan_op,
744+
stream,
745+
debug_sync));
738746
cuda_cub::throw_on_error(status, "scan failed on 2nd step");
739747

740748
status = cuda_cub::synchronize(policy);
@@ -798,7 +806,8 @@ inclusive_scan(execution_policy<Derived> &policy,
798806
OutputIt result,
799807
ScanOp scan_op)
800808
{
801-
int num_items = static_cast<int>(thrust::distance(first, last));
809+
typedef typename thrust::iterator_traits<InputIt>::difference_type diff_t;
810+
diff_t num_items = thrust::distance(first, last);
802811
return cuda_cub::inclusive_scan_n(policy, first, num_items, result, scan_op);
803812
}
804813

@@ -873,7 +882,8 @@ exclusive_scan(execution_policy<Derived> &policy,
873882
T init,
874883
ScanOp scan_op)
875884
{
876-
int num_items = static_cast<int>(thrust::distance(first, last));
885+
typedef typename thrust::iterator_traits<InputIt>::difference_type diff_t;
886+
diff_t num_items = thrust::distance(first, last);
877887
return cuda_cub::exclusive_scan_n(policy, first, num_items, result, init, scan_op);
878888
}
879889

0 commit comments

Comments
 (0)