Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/ScopedValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,6 @@ end

end # isdefined

include("reducers.jl")

end # module ScopedValues
25 changes: 25 additions & 0 deletions src/reducers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
mutable struct Reducer{T, Op, Init}
@atomic value::T
const op::Op
const init::Init
end
Reducer(op::Op, init::Init) where {Op, Init} = Reducer(init(), op, init)

Base.getindex(r::Reducer) = r.value
function Base.setindex!(r::Reducer{T, Op}, val::T) where {T, Op}
_, new = @atomic r.value r.op val
new
end

split(r::Reducer{T, Op, Init}) where {T, Op, Init} = Reducer(r.init()::T, r.op, r.init)
function join!(r::Reducer{T, Op, Init}, other_r::Reducer{T, Op, Init}) where {T, Op, Init}
r[] = other_r[]
end

function split(f, val::ScopedValue{<:Reducer})
reducer = split(val[])
@scoped val => reducer begin
f()
end
join!(val[], reducer)
end
30 changes: 30 additions & 0 deletions test/reducers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import Base.Threads: @spawn
import ScopedValues: Reducer, split, join!

function splitting_reduce(op::Op, arr::Array{T}, initf=()->zero(T); grainsize=16) where {T, Op}
if length(arr) <= grainsize
return reduce(op, arr; init=initf()::T)
end
reducer = ScopedValue(Reducer(op, initf))

function reduce_impl(arr, range, grainsize)
if length(range) <= grainsize
reducer[][] = reduce(op, view(arr, range); init=initf()::T)
return
end
midpoint = length(range) ÷ 2
range_a = first(range):(first(range)+midpoint)
range_b = (first(range)+midpoint+1):last(range)

@sync begin
@spawn split(reducer) do
reduce_impl(arr, range_a, grainsize)
end
reduce_impl(arr, range_b, grainsize)
end
end
reduce_impl(arr, 1:length(arr), grainsize)
return reducer[][]
end

@test splitting_reduce(+, ones(1024)) == 1024
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,5 @@ end
end
end

include("reducers.jl")