diff --git a/src/ScopedValues.jl b/src/ScopedValues.jl index a188d01..87c1c39 100644 --- a/src/ScopedValues.jl +++ b/src/ScopedValues.jl @@ -156,4 +156,6 @@ end end # isdefined +include("reducers.jl") + end # module ScopedValues diff --git a/src/reducers.jl b/src/reducers.jl new file mode 100644 index 0000000..a2e276b --- /dev/null +++ b/src/reducers.jl @@ -0,0 +1,25 @@ +mutable struct Reducer{T, Op, Init} + @atomic value::T + const op::Op + const init::Init +end +Reducer(op::Op, init::Init) where {Op, Init} = Reducer(init(), op, init) + +Base.getindex(r::Reducer) = r.value +function Base.setindex!(r::Reducer{T, Op}, val::T) where {T, Op} + _, new = @atomic r.value r.op val + new +end + +split(r::Reducer{T, Op, Init}) where {T, Op, Init} = Reducer(r.init()::T, r.op, r.init) +function join!(r::Reducer{T, Op, Init}, other_r::Reducer{T, Op, Init}) where {T, Op, Init} + r[] = other_r[] +end + +function split(f, val::ScopedValue{<:Reducer}) + reducer = split(val[]) + @scoped val => reducer begin + f() + end + join!(val[], reducer) +end \ No newline at end of file diff --git a/test/reducers.jl b/test/reducers.jl new file mode 100644 index 0000000..eafabcd --- /dev/null +++ b/test/reducers.jl @@ -0,0 +1,30 @@ +import Base.Threads: @spawn +import ScopedValues: Reducer, split, join! + +function splitting_reduce(op::Op, arr::Array{T}, initf=()->zero(T); grainsize=16) where {T, Op} + if length(arr) <= grainsize + return reduce(op, arr; init=initf()::T) + end + reducer = ScopedValue(Reducer(op, initf)) + + function reduce_impl(arr, range, grainsize) + if length(range) <= grainsize + reducer[][] = reduce(op, view(arr, range); init=initf()::T) + return + end + midpoint = length(range) รท 2 + range_a = first(range):(first(range)+midpoint) + range_b = (first(range)+midpoint+1):last(range) + + @sync begin + @spawn split(reducer) do + reduce_impl(arr, range_a, grainsize) + end + reduce_impl(arr, range_b, grainsize) + end + end + reduce_impl(arr, 1:length(arr), grainsize) + return reducer[][] +end + +@test splitting_reduce(+, ones(1024)) == 1024 \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 7a0e912..a6ba8cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -120,3 +120,5 @@ end end end +include("reducers.jl") +