Skip to content

Commit 086613d

Browse files
committed
The great Thrust index type fix, part 8: set operations.
1 parent c66f76e commit 086613d

File tree

4 files changed

+114
-40
lines changed

4 files changed

+114
-40
lines changed

testing/set_difference.cu

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,28 @@ void TestSetDifferenceMultiset(const size_t n)
211211
}
212212
DECLARE_VARIABLE_UNITTEST(TestSetDifferenceMultiset);
213213

214+
void TestSetDifferenceWithBigIndexesHelper(int magnitude)
215+
{
216+
thrust::counting_iterator<long long> begin(0);
217+
thrust::counting_iterator<long long> end = begin + (1ll << magnitude);
218+
thrust::counting_iterator<long long> end_longer = end + 1;
219+
ASSERT_EQUAL(thrust::distance(begin, end), 1ll << magnitude);
220+
221+
thrust::device_vector<long long> result;
222+
result.resize(1);
223+
thrust::set_difference(thrust::device, begin, end_longer, begin, end, result.begin());
224+
225+
thrust::host_vector<long long> expected;
226+
expected.push_back(*end);
227+
228+
ASSERT_EQUAL(result, expected);
229+
}
230+
231+
void TestSetDifferenceWithBigIndexes()
232+
{
233+
TestSetDifferenceWithBigIndexesHelper(30);
234+
TestSetDifferenceWithBigIndexesHelper(31);
235+
TestSetDifferenceWithBigIndexesHelper(32);
236+
TestSetDifferenceWithBigIndexesHelper(33);
237+
}
238+
DECLARE_UNITTEST(TestSetDifferenceWithBigIndexes);

testing/set_intersection.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,29 @@ void TestSetIntersectionMultiset(const size_t n)
251251
}
252252
DECLARE_VARIABLE_UNITTEST(TestSetIntersectionMultiset);
253253

254+
void TestSetDifferenceWithBigIndexesHelper(int magnitude)
255+
{
256+
thrust::counting_iterator<long long> begin1(0);
257+
thrust::counting_iterator<long long> begin2 = begin1 + (1ll << magnitude);
258+
thrust::counting_iterator<long long> end1 = begin2 + 1;
259+
thrust::counting_iterator<long long> end2 = begin2 + (1ll << magnitude);
260+
ASSERT_EQUAL(thrust::distance(begin2, end1), 1);
261+
262+
thrust::device_vector<long long> result;
263+
result.resize(1);
264+
thrust::set_intersection(thrust::device, begin1, end1, begin2, end2, result.begin());
265+
266+
thrust::host_vector<long long> expected;
267+
expected.push_back(*begin2);
268+
269+
ASSERT_EQUAL(result, expected);
270+
}
271+
272+
void TestSetDifferenceWithBigIndexes()
273+
{
274+
TestSetDifferenceWithBigIndexesHelper(30);
275+
TestSetDifferenceWithBigIndexesHelper(31);
276+
TestSetDifferenceWithBigIndexesHelper(32);
277+
TestSetDifferenceWithBigIndexesHelper(33);
278+
}
279+
DECLARE_UNITTEST(TestSetDifferenceWithBigIndexes);

thrust/system/cuda/detail/dispatch.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@
3434
status = call arguments; \
3535
}
3636

37+
/**
38+
* Dispatch between 32-bit and 64-bit index based versions of the same algorithm
39+
* implementation. This version assumes that callables for both branches consist
40+
* of the same tokens, and is intended to be used with Thrust-style dispatch
41+
* interfaces, that always deduce the size type from the arguments.
42+
*
43+
* This version of the macro supports providing two count variables, which is
44+
* necessary for set algorithms.
45+
*/
46+
#define THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, call, count1, count2, arguments) \
47+
if (count1 + count2 <= std::numeric_limits<thrust::detail::int32_t>::max()) { \
48+
thrust::detail::int32_t THRUST_PP_CAT2(count1, _fixed) = count1; \
49+
thrust::detail::int32_t THRUST_PP_CAT2(count2, _fixed) = count2; \
50+
status = call arguments; \
51+
} \
52+
else { \
53+
thrust::detail::int64_t THRUST_PP_CAT2(count1, _fixed) = count1; \
54+
thrust::detail::int64_t THRUST_PP_CAT2(count2, _fixed) = count2; \
55+
status = call arguments; \
56+
}
3757
/**
3858
* Dispatch between 32-bit and 64-bit index based versions of the same algorithm
3959
* implementation. This version allows using different token sequences for callables

thrust/system/cuda/detail/set_operations.h

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -50,35 +50,36 @@ namespace __set_operations {
5050

5151
template <bool UpperBound,
5252
class IntT,
53+
class Size,
5354
class It,
5455
class T,
5556
class Comp>
5657
THRUST_DEVICE_FUNCTION void
5758
binary_search_iteration(It data,
58-
int &begin,
59-
int &end,
59+
Size &begin,
60+
Size &end,
6061
T key,
6162
int shift,
6263
Comp comp)
6364
{
6465

6566
IntT scale = (1 << shift) - 1;
66-
int mid = (int)((begin + scale * end) >> shift);
67+
Size mid = (begin + scale * end) >> shift;
6768

6869
T key2 = data[mid];
6970
bool pred = UpperBound ? !comp(key, key2) : comp(key2, key);
7071
if (pred)
71-
begin = (int)mid + 1;
72+
begin = mid + 1;
7273
else
7374
end = mid;
7475
}
7576

76-
template <bool UpperBound, class T, class It, class Comp>
77-
THRUST_DEVICE_FUNCTION int
78-
binary_search(It data, int count, T key, Comp comp)
77+
template <bool UpperBound, class Size, class T, class It, class Comp>
78+
THRUST_DEVICE_FUNCTION Size
79+
binary_search(It data, Size count, T key, Comp comp)
7980
{
80-
int begin = 0;
81-
int end = count;
81+
Size begin = 0;
82+
Size end = count;
8283
while (begin < end)
8384
binary_search_iteration<UpperBound, int>(data,
8485
begin,
@@ -89,12 +90,12 @@ namespace __set_operations {
8990
return begin;
9091
}
9192

92-
template <bool UpperBound, class IntT, class T, class It, class Comp>
93-
THRUST_DEVICE_FUNCTION int
94-
biased_binary_search(It data, int count, T key, IntT levels, Comp comp)
93+
template <bool UpperBound, class IntT, class Size, class T, class It, class Comp>
94+
THRUST_DEVICE_FUNCTION Size
95+
biased_binary_search(It data, Size count, T key, IntT levels, Comp comp)
9596
{
96-
int begin = 0;
97-
int end = count;
97+
Size begin = 0;
98+
Size end = count;
9899

99100
if (levels >= 4 && begin < end)
100101
binary_search_iteration<UpperBound, IntT>(data, begin, end, key, 9, comp);
@@ -110,18 +111,18 @@ namespace __set_operations {
110111
return begin;
111112
}
112113

113-
template <bool UpperBound, class It1, class It2, class Comp>
114-
THRUST_DEVICE_FUNCTION int
115-
merge_path(It1 a, int aCount, It2 b, int bCount, int diag, Comp comp)
114+
template <bool UpperBound, class Size, class It1, class It2, class Comp>
115+
THRUST_DEVICE_FUNCTION Size
116+
merge_path(It1 a, Size aCount, It2 b, Size bCount, Size diag, Comp comp)
116117
{
117118
typedef typename thrust::iterator_traits<It1>::value_type T;
118119

119-
int begin = thrust::max(0, diag - bCount);
120-
int end = thrust::min(diag, aCount);
120+
Size begin = thrust::max<Size>(0, diag - bCount);
121+
Size end = thrust::min<Size>(diag, aCount);
121122

122123
while (begin < end)
123124
{
124-
int mid = (begin + end) >> 1;
125+
Size mid = (begin + end) >> 1;
125126
T aKey = a[mid];
126127
T bKey = b[diag - 1 - mid];
127128
bool pred = UpperBound ? comp(aKey, bKey) : !comp(bKey, aKey);
@@ -134,7 +135,7 @@ namespace __set_operations {
134135
}
135136

136137
template <class It1, class It2, class Size, class Size2, class CompareOp>
137-
pair<Size, Size> THRUST_DEVICE_FUNCTION
138+
THRUST_DEVICE_FUNCTION pair<Size, Size>
138139
balanced_path(It1 keys1,
139140
It2 keys2,
140141
Size num_keys1,
@@ -434,7 +435,7 @@ namespace __set_operations {
434435
CompareOp compare_op;
435436
SetOp set_op;
436437
pair<Size, Size> *partitions;
437-
Size *output_count;
438+
std::size_t *output_count;
438439

439440
//---------------------------------------------------------------------
440441
// Utility functions
@@ -756,7 +757,7 @@ namespace __set_operations {
756757
CompareOp compare_op_,
757758
SetOp set_op_,
758759
pair<Size, Size> *partitions_,
759-
Size *output_count_)
760+
std::size_t * output_count_)
760761
: storage(storage_),
761762
tile_state(tile_state_),
762763
keys1_in(core::make_load_iterator(ptx_plan(), keys1_)),
@@ -801,7 +802,7 @@ namespace __set_operations {
801802
CompareOp compare_op,
802803
SetOp set_op,
803804
pair<Size, Size> *partitions,
804-
Size * output_count,
805+
std::size_t * output_count,
805806
ScanTileState tile_state,
806807
char * shmem)
807808
{
@@ -1124,7 +1125,7 @@ namespace __set_operations {
11241125
Size num_keys2,
11251126
KeysOutputIt keys_output,
11261127
ValuesOutputIt values_output,
1127-
Size * output_count,
1128+
std::size_t * output_count,
11281129
CompareOp compare_op,
11291130
SetOp set_op,
11301131
cudaStream_t stream,
@@ -1167,7 +1168,7 @@ namespace __set_operations {
11671168
Size num_tiles = (keys_total + tile_size - 1) / tile_size;
11681169

11691170
size_t tile_agent_storage;
1170-
status = ScanTileState::AllocationSize(static_cast<int>(num_tiles), tile_agent_storage);
1171+
status = ScanTileState::AllocationSize(num_tiles, tile_agent_storage);
11711172
CUDA_CUB_RET_IF_FAIL(status);
11721173

11731174
size_t vshmem_storage = core::vshmem_size(set_op_plan.shared_memory_size,
@@ -1191,7 +1192,7 @@ namespace __set_operations {
11911192
}
11921193

11931194
ScanTileState tile_state;
1194-
status = tile_state.Init(static_cast<int>(num_tiles), allocations[0], allocation_sizes[0]);
1195+
status = tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]);
11951196
CUDA_CUB_RET_IF_FAIL(status);
11961197

11971198
pair<Size, Size> *partitions = (pair<Size, Size> *)allocations[1];
@@ -1268,24 +1269,25 @@ namespace __set_operations {
12681269
bool debug_sync = THRUST_DEBUG_SYNC_FLAG;
12691270

12701271
cudaError_t status;
1271-
status = doit_step<HAS_VALUES>(NULL,
1272+
THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, doit_step<HAS_VALUES>,
1273+
num_keys1, num_keys2, (NULL,
12721274
temp_storage_bytes,
12731275
keys1_first,
12741276
keys2_first,
12751277
values1_first,
12761278
values2_first,
1277-
num_keys1,
1278-
num_keys2,
1279+
num_keys1_fixed,
1280+
num_keys2_fixed,
12791281
keys_output,
12801282
values_output,
1281-
reinterpret_cast<size_type*>(NULL),
1283+
reinterpret_cast<std::size_t*>(NULL),
12821284
compare_op,
12831285
set_op,
12841286
stream,
1285-
debug_sync);
1287+
debug_sync));
12861288
cuda_cub::throw_on_error(status, "set_operations failed on 1st step");
12871289

1288-
size_t allocation_sizes[2] = {sizeof(size_type), temp_storage_bytes};
1290+
size_t allocation_sizes[2] = {sizeof(std::size_t), temp_storage_bytes};
12891291
void * allocations[2] = {NULL, NULL};
12901292

12911293
size_t storage_size = 0;
@@ -1307,30 +1309,31 @@ namespace __set_operations {
13071309
allocation_sizes);
13081310
cuda_cub::throw_on_error(status, "set_operations failed on 2nd alias_storage");
13091311

1310-
size_type* d_output_count
1311-
= thrust::detail::aligned_reinterpret_cast<size_type*>(allocations[0]);
1312+
std::size_t* d_output_count
1313+
= thrust::detail::aligned_reinterpret_cast<std::size_t*>(allocations[0]);
13121314

1313-
status = doit_step<HAS_VALUES>(allocations[1],
1315+
THRUST_DOUBLE_INDEX_TYPE_DISPATCH(status, doit_step<HAS_VALUES>,
1316+
num_keys1, num_keys2, (allocations[1],
13141317
temp_storage_bytes,
13151318
keys1_first,
13161319
keys2_first,
13171320
values1_first,
13181321
values2_first,
1319-
num_keys1,
1320-
num_keys2,
1322+
num_keys1_fixed,
1323+
num_keys2_fixed,
13211324
keys_output,
13221325
values_output,
13231326
d_output_count,
13241327
compare_op,
13251328
set_op,
13261329
stream,
1327-
debug_sync);
1330+
debug_sync));
13281331
cuda_cub::throw_on_error(status, "set_operations failed on 2nd step");
13291332

13301333
status = cuda_cub::synchronize(policy);
13311334
cuda_cub::throw_on_error(status, "set_operations failed to synchronize");
13321335

1333-
size_type output_count = cuda_cub::get_value(policy, d_output_count);
1336+
std::size_t output_count = cuda_cub::get_value(policy, d_output_count);
13341337

13351338
return thrust::make_pair(keys_output + output_count, values_output + output_count);
13361339
}

0 commit comments

Comments
 (0)