@@ -30,6 +30,28 @@ __global__ void Sort(int *data) {
3030 BlockStore (temp_storage_store).Store (data, thread_keys);
3131}
3232
// Sorts 512 __half keys in ascending order using block-level cub collectives.
//
// The collectives are instantiated for 128 threads with 4 keys per thread, so
// `data` must point to at least 128 * 4 = 512 device elements and the kernel
// must be launched with a one-dimensional block of 128 threads. No bounds
// checking is performed, and the kernel indexes `data` from its start without
// any per-block offset, so it is intended for a single-block launch.
//
// NOTE: the FileCheck directives below match the migrated output line by
// line. Keep the three alias declarations, and the three Load/Sort/Store
// calls, on consecutive lines with exactly this spelling.
__global__ void SortHalf(__half *data) {
  // CHECK: using BlockRadixSort = dpct::group::group_radix_sort<sycl::half, 4>;
  // CHECK-NEXT: using BlockLoad = dpct::group::group_load<sycl::half, 4>;
  // CHECK-NEXT: using BlockStore = dpct::group::group_store<sycl::half, 4>;
  // CHECK-NOT: __shared__ typename BlockLoad::TempStorage temp_storage_load;
  // CHECK-NOT: __shared__ typename BlockStore::TempStorage temp_storage_store;
  // CHECK-NOT: __shared__ typename BlockRadixSort::TempStorage temp_storage;
  using BlockRadixSort = cub::BlockRadixSort<__half, 128, 4>;
  using BlockLoad = cub::BlockLoad<__half, 128, 4>;
  using BlockStore = cub::BlockStore<__half, 128, 4>;
  __shared__ typename BlockLoad::TempStorage temp_storage_load;
  __shared__ typename BlockStore::TempStorage temp_storage_store;
  __shared__ typename BlockRadixSort::TempStorage temp_storage;
  __half thread_keys[4];
  // CHECK: BlockLoad(temp_storage_load).load(item_ct1, data, thread_keys);
  // CHECK-NEXT: BlockRadixSort(temp_storage).sort(item_ct1, thread_keys);
  // CHECK-NEXT: BlockStore(temp_storage_store).store(item_ct1, data, thread_keys);
  BlockLoad(temp_storage_load).Load(data, thread_keys);
  BlockRadixSort(temp_storage).Sort(thread_keys);
  BlockStore(temp_storage_store).Store(data, thread_keys);
}
54+
3355__global__ void SortDescending (int *data) {
3456 // CHECK: using BlockRadixSort = dpct::group::group_radix_sort<int, 4>;
3557 // CHECK-NEXT: using BlockLoad = dpct::group::group_load<int, 4, dpct::group::group_load_algorithm::blocked>;
@@ -171,8 +193,9 @@ __global__ void test_unsupported(int *data) {
171193
// Prints the N elements of `arr` to standard output on a single line,
// separated by commas and terminated by a newline.
//
// Each element is converted to int before it is streamed, so element types
// without their own stream operator can be printed as long as they convert to
// int; any fractional part is dropped by that conversion.
template <typename T, int N>
void print_array(T (&arr)[N]) {
  for (int i = 0; i < N; ++i) {
    // The separator must be a plain char literal: a padded literal such as
    // ' ,' is a multi-character constant of type int and would be streamed as
    // a number instead of a character.
    std::cout << static_cast<int>(arr[i]) << (i == N - 1 ? '\n' : ',');
  }
}
177200
178201bool test_sort () {
@@ -211,6 +234,42 @@ bool test_sort() {
211234 return true ;
212235}
213236
// Host-side driver for the SortHalf kernel.
//
// Fills a 512-element __half array with a shuffled permutation of 0..511,
// sorts it on the device with one 128-thread block, copies the result back and
// verifies that element i equals i. Prints a pass/fail message (plus the
// array contents on failure).
//
// Returns true when the device output is fully sorted, false otherwise.
bool test_sorthalf() {
  __half data[512] = {0}, *d_data = nullptr;
  cudaMalloc(&d_data, sizeof(__half) * 512);
  // Interleave an ascending run (0, 1, 2, ...) with a descending run
  // (511, 510, ...): x ends at 255 and y at 256, so together they cover every
  // value in 0..511 exactly once, in an unsorted order.
  for (int i = 0, x = 0, y = 511; i < 128; ++i) {
    data[i * 4 + 0] = x++;
    data[i * 4 + 1] = y--;
    data[i * 4 + 2] = x++;
    data[i * 4 + 3] = y--;
  }
  cudaMemcpy(d_data, data, sizeof(data), cudaMemcpyHostToDevice);
  // The directives below pin the migrated form of the single launch statement
  // that follows them; keep that statement on one line.
  // CHECK: q_ct1.submit(
  // CHECK-NEXT: [&](sycl::handler &cgh) {
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_load_acc(dpct::group::group_load<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_store_acc(dpct::group::group_store<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-NEXT: sycl::local_accessor<uint8_t, 1> temp_storage_acc(dpct::group::group_radix_sort<sycl::half, 4>::get_local_memory_size(sycl::range<3>(1, 1, 128).size()), cgh);
  // CHECK-EMPTY:
  // CHECK-NEXT: cgh.parallel_for(
  // CHECK-NEXT: sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
  // CHECK-NEXT: [=](sycl::nd_item<3> item_ct1) {
  // CHECK-NEXT: SortHalf(d_data, &temp_storage_load_acc[0], &temp_storage_store_acc[0], &temp_storage_acc[0]);
  // CHECK-NEXT: });
  // CHECK-NEXT: });
  SortHalf<<<1, 128>>>(d_data);
  cudaDeviceSynchronize();
  cudaMemcpy(data, d_data, sizeof(data), cudaMemcpyDeviceToHost);
  // The device buffer is no longer needed once the result is on the host, so
  // it is released before verification and both return paths stay leak-free.
  cudaFree(d_data);
  // Integers up to 2048 are exactly representable in half precision, so the
  // exact comparison against i is valid for all 512 values.
  for (int i = 0; i < 512; ++i) {
    if (static_cast<int>(data[i]) != i) {
      printf("test_sorthalf failed\n");
      print_array(data);
      return false;
    }
  }
  printf("test_sorthalf pass\n");
  return true;
}
272+
214273bool test_sort_descending () {
215274 int data[512 ] = {0 }, *d_data = nullptr ;
216275 cudaMalloc (&d_data, sizeof (int ) * 512 );
@@ -610,7 +669,7 @@ bool test_sort_descending_blocked_to_striped_bit() {
610669}
611670
612671int main () {
613- return !(test_sort () && test_sort_descending () &&
672+ return !(test_sort () && test_sorthalf () && test_sort_descending () &&
614673 test_sort_blocked_to_striped () &&
615674 test_sort_descending_blocked_to_striped () && test_sort_bit () &&
616675 test_sort_descending_bit () && test_sort_blocked_to_striped_bit () &&
0 commit comments