|
3 | 3 | #include <thrust/functional.h> |
4 | 4 | #include <thrust/iterator/discard_iterator.h> |
5 | 5 | #include <thrust/iterator/retag.h> |
| 6 | +#include <thrust/device_malloc.h> |
| 7 | +#include <thrust/device_free.h> |
6 | 8 |
|
7 | 9 |
|
8 | 10 | template<typename T> |
@@ -555,3 +557,93 @@ void TestInclusiveScanWithIndirection(void) |
555 | 557 | } |
556 | 558 | DECLARE_INTEGRAL_VECTOR_UNITTEST(TestInclusiveScanWithIndirection); |
557 | 559 |
|
| 560 | +struct only_set_when_expected_it |
| 561 | +{ |
| 562 | + long long expected; |
| 563 | + bool * flag; |
| 564 | + |
| 565 | + __host__ __device__ only_set_when_expected_it operator++() const { return *this; } |
| 566 | + __host__ __device__ only_set_when_expected_it operator*() const { return *this; } |
| 567 | + template<typename Difference> |
| 568 | + __host__ __device__ only_set_when_expected_it operator+(Difference) const { return *this; } |
| 569 | + template<typename Index> |
| 570 | + __host__ __device__ only_set_when_expected_it operator[](Index) const { return *this; } |
| 571 | + |
| 572 | + __device__ |
| 573 | + void operator=(long long value) const |
| 574 | + { |
| 575 | + if (value == expected) |
| 576 | + { |
| 577 | + *flag = true; |
| 578 | + } |
| 579 | + } |
| 580 | +}; |
| 581 | + |
| 582 | +namespace thrust |
| 583 | +{ |
| 584 | +template<> |
| 585 | +struct iterator_traits<only_set_when_expected_it> |
| 586 | +{ |
| 587 | + typedef long long value_type; |
| 588 | + typedef only_set_when_expected_it reference; |
| 589 | +}; |
| 590 | +} |
| 591 | + |
| 592 | +void TestInclusiveScanWithBigIndexesHelper(int magnitude) |
| 593 | +{ |
| 594 | + thrust::constant_iterator<long long> begin(1); |
| 595 | + thrust::constant_iterator<long long> end = begin + (1ll << magnitude); |
| 596 | + ASSERT_EQUAL(thrust::distance(begin, end), 1ll << magnitude); |
| 597 | + |
| 598 | + thrust::device_ptr<bool> has_executed = thrust::device_malloc<bool>(1); |
| 599 | + *has_executed = false; |
| 600 | + |
| 601 | + only_set_when_expected_it out = { (1ll << magnitude), thrust::raw_pointer_cast(has_executed) }; |
| 602 | + |
| 603 | + thrust::inclusive_scan(thrust::device, begin, end, out); |
| 604 | + |
| 605 | + bool has_executed_h = *has_executed; |
| 606 | + thrust::device_free(has_executed); |
| 607 | + |
| 608 | + ASSERT_EQUAL(has_executed_h, true); |
| 609 | +} |
| 610 | + |
| 611 | +void TestInclusiveScanWithBigIndexes() |
| 612 | +{ |
| 613 | + TestInclusiveScanWithBigIndexesHelper(30); |
| 614 | + TestInclusiveScanWithBigIndexesHelper(31); |
| 615 | + TestInclusiveScanWithBigIndexesHelper(32); |
| 616 | + TestInclusiveScanWithBigIndexesHelper(33); |
| 617 | +} |
| 618 | + |
| 619 | +DECLARE_UNITTEST(TestInclusiveScanWithBigIndexes); |
| 620 | + |
| 621 | +void TestExclusiveScanWithBigIndexesHelper(int magnitude) |
| 622 | +{ |
| 623 | + thrust::constant_iterator<long long> begin(1); |
| 624 | + thrust::constant_iterator<long long> end = begin + (1ll << magnitude); |
| 625 | + ASSERT_EQUAL(thrust::distance(begin, end), 1ll << magnitude); |
| 626 | + |
| 627 | + thrust::device_ptr<bool> has_executed = thrust::device_malloc<bool>(1); |
| 628 | + *has_executed = false; |
| 629 | + |
| 630 | + only_set_when_expected_it out = { (1ll << magnitude) - 1, thrust::raw_pointer_cast(has_executed) }; |
| 631 | + |
| 632 | + thrust::exclusive_scan(thrust::device, begin, end, out,0ll); |
| 633 | + |
| 634 | + bool has_executed_h = *has_executed; |
| 635 | + thrust::device_free(has_executed); |
| 636 | + |
| 637 | + ASSERT_EQUAL(has_executed_h, true); |
| 638 | +} |
| 639 | + |
| 640 | +void TestExclusiveScanWithBigIndexes() |
| 641 | +{ |
| 642 | + TestExclusiveScanWithBigIndexesHelper(30); |
| 643 | + TestExclusiveScanWithBigIndexesHelper(31); |
| 644 | + TestExclusiveScanWithBigIndexesHelper(32); |
| 645 | + TestExclusiveScanWithBigIndexesHelper(33); |
| 646 | +} |
| 647 | + |
| 648 | +DECLARE_UNITTEST(TestExclusiveScanWithBigIndexes); |
| 649 | + |
0 commit comments