@@ -12,27 +12,28 @@ const ACC_FLAG_P::UInt8 = 1 # Only current block's prefix available
 end
 
 
-@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_block!(
+function _accumulate_block!(
     op, v, init, neutral,
     inclusive,
     flags, prefixes,    # one per block
-)
+    ::Val{block_size}
+) where block_size
+    @inbounds begin
     # NOTE: shmem_size MUST be greater than 2 * block_size
     # NOTE: block_size MUST be a power of 2
     len = length(v)
-    @uniform block_size = @groupsize()[1]
-    temp = @localmem eltype(v) (0x2 * block_size + conflict_free_offset(0x2 * block_size),)
+    temp = KI.localmemory(eltype(v), 0x2 * block_size + conflict_free_offset(0x2 * block_size))
 
     # NOTE: for many index calculations in this library, computation using zero-indexing leads to
     # fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
     # indexing). Internal calculations will be done using zero indexing except when actually
     # accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.
 
     # Group (block) and local (thread) indices
-    iblock = @index(Group, Linear) - 0x1
-    ithread = @index(Local, Linear) - 0x1
+    iblock = KI.get_group_id().x - 0x1
+    ithread = KI.get_local_id().x - 0x1
 
-    num_blocks = @ndrange()[1] ÷ block_size
+    num_blocks = KI.get_num_groups().x
     block_offset = iblock * block_size * 0x2    # Processing two elements per thread
 
     # Copy two elements from the main array; offset indices to avoid bank conflicts

...

     next_pow2 = block_size * 0x2
     d = next_pow2 >> 0x1
     while d > 0x0    # TODO: unroll this like in reduce.jl?
-        @synchronize()
+        KI.barrier()
 
         if ithread < d
             _ai = offset * (0x2 * ithread + 0x1) - 0x1

...

     d = typeof(ithread)(1)
     while d < next_pow2
         offset = offset >> 0x1
-        @synchronize()
+        KI.barrier()
 
         if ithread < d
             _ai = offset * (0x2 * ithread + 0x1) - 0x1
@@ -103,10 +104,10 @@ end
     # Later blocks should always be inclusively-scanned
     if inclusive || iblock != 0x0
         # To compute an inclusive scan, shift elements left...
-        @synchronize()
+        KI.barrier()
         t1 = temp[ai + bank_offset_a + 0x1]
         t2 = temp[bi + bank_offset_b + 0x1]
-        @synchronize()
+        KI.barrier()
 
         if ai > 0x0
             temp[ai - 0x1 + conflict_free_offset(ai - 0x1) + 0x1] = t1

...

         end
     end
 
-    @synchronize()
+    KI.barrier()
 
     # Write this block's final prefix to global array and set flag to "block prefix computed"
     if bi == 0x2 * block_size - 0x1
@@ -145,24 +146,25 @@ end
     if block_offset + bi < len
         v[block_offset + bi + 0x1] = temp[bi + bank_offset_b + 0x1]
     end
+    end
+    nothing
 end
 
 
-@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous!(
-    op, v, flags, @Const(prefixes),
-)
-
+function _accumulate_previous!(
+    op, v, flags, prefixes, ::Val{block_size}
+) where block_size
+    @inbounds begin
     len = length(v)
-    block_size = @groupsize()[1]
 
     # NOTE: for many index calculations in this library, computation using zero-indexing leads to
     # fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
     # indexing). Internal calculations will be done using zero indexing except when actually
     # accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.
 
     # Group (block) and local (thread) indices
-    iblock = @index(Group, Linear) - 0x1 + 0x1    # Skipping first block
-    ithread = @index(Local, Linear) - 0x1
+    iblock = KI.get_group_id().x - 0x1 + 0x1    # Skipping first block
+    ithread = KI.get_local_id().x - 0x1
     block_offset = iblock * block_size * 0x2    # Processing two elements per thread
 
     # Each block looks back to find running prefix sum

...
     # There are two synchronization concerns here:
     # 1. Within a group we want to ensure that all writes to `v` have occurred before setting the flag.
     # 2. Between groups we need to use a fence and atomic load/store to ensure that memory operations are not re-ordered.
-    @synchronize()    # within-block
+    KI.barrier()    # within-block
     # Note: This fence is needed to ensure that the flag is not set before copying into v.
     # See https://doc.rust-lang.org/std/sync/atomic/fn.fence.html
     # for more details.
@@ -206,24 +208,26 @@ end
     if ithread == 0x0
         UnsafeAtomics.store!(pointer(flags, iblock + 0x1), convert(eltype(flags), ACC_FLAG_A), UnsafeAtomics.monotonic)
     end
+    end
+    nothing
 end
 
 
-@kernel cpu=false inbounds=true unsafe_indices=true function _accumulate_previous_coupled_preblocks!(
-    op, v, prefixes,
-)
+function _accumulate_previous_coupled_preblocks!(
+    op, v, prefixes, ::Val{block_size}
+) where block_size
+    @inbounds begin
     # No decoupled lookback
     len = length(v)
-    block_size = @groupsize()[1]
 
     # NOTE: for many index calculations in this library, computation using zero-indexing leads to
     # fewer operations (also code is transpiled to CUDA / ROCm / oneAPI / Metal code which do zero
     # indexing). Internal calculations will be done using zero indexing except when actually
     # accessing memory. As with C, the lower bound is inclusive, the upper bound exclusive.
 
     # Group (block) and local (thread) indices
-    iblock = @index(Group, Linear) - 0x1 + 0x1    # Skipping first block
-    ithread = @index(Local, Linear) - 0x1
+    iblock = KI.get_group_id().x - 0x1 + 0x1    # Skipping first block
+    ithread = KI.get_local_id().x - 0x1
     block_offset = iblock * block_size * 0x2    # Processing two elements per thread
 
     # Each block looks back to find running prefix sum

...

     if block_offset + bi < len
         v[block_offset + bi + 0x1] = op(running_prefix, v[block_offset + bi + 0x1])
     end
+    end
+    nothing
 end
 
 
@@ -298,14 +304,10 @@ function accumulate_1d_gpu!(
         flags = temp_flags
     end
 
-    kernel1! = _accumulate_block!(backend, block_size)
-    kernel1!(op, v, init, neutral, inclusive, flags, prefixes,
-             ndrange=num_blocks * block_size)
+    KI.@kernel backend workgroupsize=block_size numworkgroups=num_blocks _accumulate_block!(op, v, init, neutral, inclusive, flags, prefixes, Val(block_size))
 
     if num_blocks > 1
-        kernel2! = _accumulate_previous!(backend, block_size)
-        kernel2!(op, v, flags, prefixes,
-                 ndrange=(num_blocks - 1) * block_size)
+        KI.@kernel backend workgroupsize=block_size numworkgroups=(num_blocks - 1) _accumulate_previous!(op, v, flags, prefixes, Val(block_size))
     end
 
     return v
@@ -349,22 +351,17 @@ function accumulate_1d_gpu!(
         prefixes = temp
     end
 
-    kernel1! = _accumulate_block!(backend, block_size)
-    kernel1!(op, v, init, neutral, inclusive, nothing, prefixes,
-             ndrange=num_blocks * block_size)
+    KI.@kernel backend workgroupsize=block_size numworkgroups=num_blocks _accumulate_block!(op, v, init, neutral, inclusive, nothing, prefixes, Val(block_size))
 
     if num_blocks > 1
 
         # Accumulate prefixes of all blocks; use neutral as init here to not reinclude init
         num_blocks_prefixes = (length(prefixes) + elems_per_block - 1) ÷ elems_per_block
-        kernel1!(op, prefixes, neutral, neutral, true, nothing, nothing,
-                 ndrange=num_blocks_prefixes * block_size)
+        KI.@kernel backend workgroupsize=block_size numworkgroups=num_blocks_prefixes _accumulate_block!(op, prefixes, neutral, neutral, true, nothing, nothing, Val(block_size))
 
         # Prefixes are pre-accumulated (completely accumulated if num_blocks_prefixes == 1, or
         # partially, which we will account for in the coupled lookback)
-        kernel2! = _accumulate_previous_coupled_preblocks!(backend, block_size)
-        kernel2!(op, v, prefixes,
-                 ndrange=(num_blocks - 1) * block_size)
+        KI.@kernel backend workgroupsize=block_size numworkgroups=(num_blocks - 1) _accumulate_previous_coupled_preblocks!(op, v, prefixes, Val(block_size))
     end
 
     return v
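
For orientation, the sketch below shows the kernel-definition and launch pattern this diff migrates to, using only the KI calls that already appear above (`KI.get_group_id`, `KI.get_local_id`, `KI.@kernel`). The kernel name `_scale_demo!`, the `backend`, `block_size`, and `factor` variables, and the scaling operation are illustrative placeholders, not part of this change.

```julia
# Minimal sketch (assumed, not from this PR): a plain kernel function
# parameterized by Val{block_size}, launched via KI.@kernel as in the diff above.
function _scale_demo!(v, factor, ::Val{block_size}) where block_size
    @inbounds begin
        # Zero-based block/thread indices, mirroring the kernels in this PR
        iblock = KI.get_group_id().x - 0x1
        ithread = KI.get_local_id().x - 0x1
        i = iblock * block_size + ithread
        if i < length(v)
            v[i + 0x1] *= factor    # one element per thread, bounds-checked manually
        end
    end
    nothing
end

# Hypothetical launch, mirroring the call sites in accumulate_1d_gpu!
num_blocks = (length(v) + block_size - 1) ÷ block_size
KI.@kernel backend workgroupsize=block_size numworkgroups=num_blocks _scale_demo!(v, factor, Val(block_size))
```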