@@ -5,31 +5,29 @@ include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) #
55
66# Function to use as a baseline for CPU metrics
"""
    create_histogram(input)

Count how often each value occurs in `input`; used as the CPU baseline.

Each element of `input` is used directly as a 1-based bin index, so entry `k` of the
result is the number of elements equal to `k`. The result has `maximum(input)` bins
and element type `eltype(input)`. An empty `input` yields an empty histogram.
"""
function create_histogram(input)
    # `maximum` throws on an empty collection, so handle that case up front.
    isempty(input) && return zeros(eltype(input), 0)

    histogram_output = zeros(eltype(input), maximum(input))
    for i in input
        histogram_output[i] += 1
    end
    return histogram_output
end
1414
1515# This is a 1D histogram kernel where the histogramming happens on shmem
16- @kernel function histogram_kernel! (histogram_output, input)
17- tid = @index (Global , Linear)
16+ @kernel unsafe_indices = true function histogram_kernel! (histogram_output, input)
17+ gid = @index (Group , Linear)
1818 lid = @index (Local, Linear)
1919
20- @uniform warpsize = Int (32 )
21-
22- @uniform gs = @groupsize ()[1 ]
20+ @uniform gs = prod (@groupsize ())
21+ tid = (gid - 1 ) * gs + lid
2322 @uniform N = length (histogram_output)
2423
25- shared_histogram = @localmem Int (gs)
24+ shared_histogram = @localmem eltype (input) (gs)
2625
2726 # This will go through all input elements and assign them to a location in
2827 # shmem. Note that if there is not enough shmem, we create different shmem
2928 # blocks to write to. For example, if shmem is of size 256, but it's
3029 # possible to get a value of 312, then we will have 2 separate shmem blocks,
3130 # one from 1->256, and another from 257->512
32- @uniform max_element = 1
3331 for min_element in 1 : gs: N
3432
3533 # Setting shared_histogram to 0
4240 end
4341
4442 # Defining bin on shared memory and writing to it if possible
45- bin = input[tid]
43+ bin = tid <= length ( input) ? input [tid] : 0
4644 if bin >= min_element && bin < max_element
4745 bin -= min_element - 1
4846 @atomic shared_histogram[bin] += 1
5856
5957end
6058
"""
    histogram!(histogram_output, input, groupsize = 256)

Launch `histogram_kernel!` on the backend of `histogram_output`, with one work-item
requested per element of `input` (`ndrange = size(input)`).

`groupsize` is the workgroup size handed to the kernel when it is instantiated.
Returns `nothing`; this function does not call `synchronize` itself.
"""
function histogram!(histogram_output, input, groupsize = 256)
    # The kernel needs a static block size, so it is fixed at instantiation time.
    launch! = histogram_kernel!(get_backend(histogram_output), (groupsize,))
    launch!(histogram_output, input; ndrange = size(input))
    return nothing
end
@@ -74,9 +72,10 @@ function move(backend, input)
7472end
7573
7674@testset " histogram tests" begin
77- rand_input = [rand (1 : 128 ) for i in 1 : 1000 ]
78- linear_input = [i for i in 1 : 1024 ]
79- all_two = [2 for i in 1 : 512 ]
75+ # Use Int32 as some backends don't support 64-bit atomics
76+ rand_input = Int32 .(rand (1 : 128 , 1000 ))
77+ linear_input = Int32 .(1 : 1024 )
78+ all_two = fill (Int32 (2 ), 512 )
8079
8180 histogram_rand_baseline = create_histogram (rand_input)
8281 histogram_linear_baseline = create_histogram (linear_input)
8685 linear_input = move (backend, linear_input)
8786 all_two = move (backend, all_two)
8887
89- rand_histogram = KernelAbstractions. zeros (backend, Int, 128 )
90- linear_histogram = KernelAbstractions. zeros (backend, Int, 1024 )
91- two_histogram = KernelAbstractions. zeros (backend, Int, 2 )
88+ rand_histogram = KernelAbstractions. zeros (backend, eltype (rand_input), Int ( maximum (rand_input)) )
89+ linear_histogram = KernelAbstractions. zeros (backend, eltype (linear_input), Int ( maximum (linear_input)) )
90+ two_histogram = KernelAbstractions. zeros (backend, eltype (all_two), Int ( maximum (all_two)) )
9291
93- histogram! (rand_histogram, rand_input)
92+ histogram! (rand_histogram, rand_input, 6 )
9493 histogram! (linear_histogram, linear_input)
9594 histogram! (two_histogram, all_two)
96- KernelAbstractions. synchronize (CPU () )
95+ KernelAbstractions. synchronize (backend )
9796
9897 @test isapprox (Array (rand_histogram), histogram_rand_baseline)
9998 @test isapprox (Array (linear_histogram), histogram_linear_baseline)
0 commit comments