@@ -74,34 +74,30 @@ extern "C" int _system_poseidon2_tracegen(
7474 return cudaGetLastError ();
7575}
7676
77- // Reduces the records, removing duplicates and storing the number of times
78- // each occurs in d_counts. The number of records after reduction is stored
79- // into host pointer num_records.
80- extern " C" int _system_poseidon2_deduplicate_records (
77+ // Prepares d_counts (fills it with 1s) for use with sort and reduce, and stores in
78+ // h_temp_bytes_out the temporary buffer size necessary for both cub functions.
79+ extern " C" int _system_poseidon2_deduplicate_records_get_temp_bytes (
8180 Fp *d_records,
8281 uint32_t *d_counts,
83- size_t *num_records
82+ size_t num_records,
83+ size_t *d_num_records,
84+ size_t *h_temp_bytes_out
8485) {
85- auto [grid, block] = kernel_launch_params (* num_records);
86+ auto [grid, block] = kernel_launch_params (num_records);
8687 FpArray<16 > *d_records_fp16 = reinterpret_cast <FpArray<16 > *>(d_records);
87- size_t *d_num_records;
8888
8989 // We want to sort and reduce the raw records, keeping track of how many
90- // each occurs in d_counts. To prepare for reduce we need to a) allocate
91- // d_num_records, b) fill d_counts with 1s, and c) group keys together
92- // using sort.
93- cudaMallocAsync (&d_num_records, sizeof (size_t ), cudaStreamPerThread);
94- cudaMemcpyAsync (
95- d_num_records, num_records, sizeof (size_t ), cudaMemcpyHostToDevice, cudaStreamPerThread
96- );
97- fill_buffer<uint32_t ><<<grid, block, 0 , cudaStreamPerThread>>> (d_counts, 1 , *num_records);
90+ // each occurs in d_counts. To prepare for reduce we need to a) fill
91+ // d_counts with 1s, and b) group keys together using sort. Note we do
92+ // b) in the kernel below.
93+ fill_buffer<uint32_t ><<<grid, block>>> (d_counts, 1 , num_records);
9894
9995 size_t sort_storage_bytes = 0 ;
10096 cub::DeviceMergeSort::SortKeys (
10197 nullptr ,
10298 sort_storage_bytes,
10399 d_records_fp16,
104- * num_records,
100+ num_records,
105101 Fp16CompareOp (),
106102 cudaStreamPerThread
107103 );
@@ -116,13 +112,27 @@ extern "C" int _system_poseidon2_deduplicate_records(
116112 d_counts,
117113 d_num_records,
118114 std::plus (),
119- * num_records,
115+ num_records,
120116 cudaStreamPerThread
121117 );
122118
123- size_t temp_storage_bytes = std::max (sort_storage_bytes, reduce_storage_bytes);
124- void *d_temp_storage = nullptr ;
125- cudaMallocAsync (&d_temp_storage, temp_storage_bytes, cudaStreamPerThread);
119+ *h_temp_bytes_out = std::max (sort_storage_bytes, reduce_storage_bytes);
120+ return cudaGetLastError ();
121+ }
122+
123+ // Reduces the records, removing duplicates and storing the number of times
124+ // each occurs in d_counts. The number of records after reduction is written
125+ // to d_num_records (not copied back to the host here). temp_storage_bytes should
126+ // be computed using _system_poseidon2_deduplicate_records_get_temp_bytes.
127+ extern " C" int _system_poseidon2_deduplicate_records (
128+ Fp *d_records,
129+ uint32_t *d_counts,
130+ size_t num_records,
131+ size_t *d_num_records,
132+ void *d_temp_storage,
133+ size_t temp_storage_bytes
134+ ) {
135+ FpArray<16 > *d_records_fp16 = reinterpret_cast <FpArray<16 > *>(d_records);
126136
127137 // TODO: We currently can't use DeviceRadixSort since each key is 64 bytes
128138 // which causes Fp16Decomposer usage to exceed shared memory. We need to
@@ -131,7 +141,7 @@ extern "C" int _system_poseidon2_deduplicate_records(
131141 d_temp_storage,
132142 temp_storage_bytes,
133143 d_records_fp16,
134- * num_records,
144+ num_records,
135145 Fp16CompareOp (),
136146 cudaStreamPerThread
137147 );
@@ -148,14 +158,9 @@ extern "C" int _system_poseidon2_deduplicate_records(
148158 d_counts,
149159 d_num_records,
150160 std::plus (),
151- * num_records,
161+ num_records,
152162 cudaStreamPerThread
153163 );
154164
155- cudaMemcpyAsync (
156- num_records, d_num_records, sizeof (size_t ), cudaMemcpyDeviceToHost, cudaStreamPerThread
157- );
158- cudaFreeAsync (d_num_records, cudaStreamPerThread);
159- cudaFreeAsync (d_temp_storage, cudaStreamPerThread);
160165 return cudaGetLastError ();
161166}
0 commit comments