diff --git a/src/Dagger.jl b/src/Dagger.jl index fa30c7c1..102a7614 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -73,6 +73,9 @@ include("utils/fetch.jl") include("utils/chunks.jl") include("utils/logging.jl") include("submission.jl") +abstract type MemorySpace end +include("utils/memory-span.jl") +include("utils/interval_tree.jl") include("memory-spaces.jl") # Task scheduling @@ -83,7 +86,12 @@ include("utils/caching.jl") include("sch/Sch.jl"); using .Sch # Data dependency task queue -include("datadeps.jl") +include("datadeps/aliasing.jl") +include("datadeps/chunkview.jl") +include("datadeps/remainders.jl") +include("datadeps/queue.jl") + +# Stencils include("utils/haloarray.jl") include("stencil.jl") diff --git a/src/argument.jl b/src/argument.jl index 94246a75..849486e0 100644 --- a/src/argument.jl +++ b/src/argument.jl @@ -20,6 +20,7 @@ function pos_kw(pos::ArgPosition) @assert pos.kw != :NULL return pos.kw end + mutable struct Argument pos::ArgPosition value @@ -41,6 +42,35 @@ function Base.iterate(arg::Argument, state::Bool) return nothing end end - Base.copy(arg::Argument) = Argument(ArgPosition(arg.pos), arg.value) chunktype(arg::Argument) = chunktype(value(arg)) + +mutable struct TypedArgument{T} + pos::ArgPosition + value::T +end +TypedArgument(pos::Integer, value::T) where T = TypedArgument{T}(ArgPosition(true, pos, :NULL), value) +TypedArgument(kw::Symbol, value::T) where T = TypedArgument{T}(ArgPosition(false, 0, kw), value) +Base.setproperty!(arg::TypedArgument, name::Symbol, value::T) where T = + throw(ArgumentError("Cannot set properties of TypedArgument")) +ispositional(arg::TypedArgument) = ispositional(arg.pos) +iskw(arg::TypedArgument) = iskw(arg.pos) +pos_idx(arg::TypedArgument) = pos_idx(arg.pos) +pos_kw(arg::TypedArgument) = pos_kw(arg.pos) +raw_position(arg::TypedArgument) = raw_position(arg.pos) +value(arg::TypedArgument) = arg.value +valuetype(arg::TypedArgument{T}) where T = T +Base.iterate(arg::TypedArgument) = (arg.pos, true) 
+function Base.iterate(arg::TypedArgument, state::Bool) + if state + return (arg.value, false) + else + return nothing + end +end +Base.copy(arg::TypedArgument{T}) where T = TypedArgument{T}(ArgPosition(arg.pos), arg.value) +chunktype(arg::TypedArgument) = chunktype(value(arg)) + +Argument(arg::TypedArgument) = Argument(arg.pos, arg.value) + +const AnyArgument = Union{Argument, TypedArgument} \ No newline at end of file diff --git a/src/datadeps.jl b/src/datadeps.jl deleted file mode 100644 index d20bda64..00000000 --- a/src/datadeps.jl +++ /dev/null @@ -1,1082 +0,0 @@ -import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv - -export In, Out, InOut, Deps, spawn_datadeps - -"Specifies a read-only dependency." -struct In{T} - x::T -end -"Specifies a write-only dependency." -struct Out{T} - x::T -end -"Specifies a read-write dependency." -struct InOut{T} - x::T -end -"Specifies one or more dependencies." -struct Deps{T,DT<:Tuple} - x::T - deps::DT -end -Deps(x, deps...) 
= Deps(x, deps) - -struct DataDepsTaskQueue <: AbstractTaskQueue - # The queue above us - upper_queue::AbstractTaskQueue - # The set of tasks that have already been seen - seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing} - # The data-dependency graph of all tasks - g::Union{SimpleDiGraph{Int},Nothing} - # The mapping from task to graph ID - task_to_id::Union{Dict{DTask,Int},Nothing} - # How to traverse the dependency graph when launching tasks - traversal::Symbol - # Which scheduler to use to assign tasks to processors - scheduler::Symbol - - # Whether aliasing across arguments is possible - # The fields following only apply when aliasing==true - aliasing::Bool - - function DataDepsTaskQueue(upper_queue; - traversal::Symbol=:inorder, - scheduler::Symbol=:naive, - aliasing::Bool=true) - seen_tasks = Pair{DTaskSpec,DTask}[] - g = SimpleDiGraph() - task_to_id = Dict{DTask,Int}() - return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, - aliasing) - end -end - -function unwrap_inout(arg) - readdep = false - writedep = false - if arg isa In - readdep = true - arg = arg.x - elseif arg isa Out - writedep = true - arg = arg.x - elseif arg isa InOut - readdep = true - writedep = true - arg = arg.x - elseif arg isa Deps - alldeps = Tuple[] - for dep in arg.deps - dep_mod, inner_deps = unwrap_inout(dep) - for (_, readdep, writedep) in inner_deps - push!(alldeps, (dep_mod, readdep, writedep)) - end - end - arg = arg.x - return arg, alldeps - else - readdep = true - end - return arg, Tuple[(identity, readdep, writedep)] -end - -function enqueue!(queue::DataDepsTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.seen_tasks, spec) -end -function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.seen_tasks, specs) -end - -_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? 
objectid(arg) : hash(arg, h) -_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) -_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) - -struct ArgumentWrapper - arg - dep_mod - hash::UInt - - function ArgumentWrapper(arg, dep_mod) - h = hash(dep_mod) - h = _identity_hash(arg, h) - return new(arg, dep_mod, h) - end -end -Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) -Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = - aw1.hash == aw2.hash - -struct DataDepsAliasingState - # Track original and current data locations - # We track data => space - data_origin::Dict{AliasingWrapper,MemorySpace} - data_locality::Dict{AliasingWrapper,MemorySpace} - - # Track writers ("owners") and readers - ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} - ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} - ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} - - # Cache ainfo lookups - ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} - - function DataDepsAliasingState() - data_origin = Dict{AliasingWrapper,MemorySpace}() - data_locality = Dict{AliasingWrapper,MemorySpace}() - - ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() - ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() - ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() - - ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() - - return new(data_origin, data_locality, - ainfos_owner, ainfos_readers, ainfos_overlaps, - ainfo_cache) - end -end -struct DataDepsNonAliasingState - # Track original and current data locations - # We track data => space - data_origin::IdDict{Any,MemorySpace} - data_locality::IdDict{Any,MemorySpace} - - # Track writers ("owners") and readers - args_owner::IdDict{Any,Union{Pair{DTask,Int},Nothing}} - args_readers::IdDict{Any,Vector{Pair{DTask,Int}}} - - function 
DataDepsNonAliasingState() - data_origin = IdDict{Any,MemorySpace}() - data_locality = IdDict{Any,MemorySpace}() - - args_owner = IdDict{Any,Union{Pair{DTask,Int},Nothing}}() - args_readers = IdDict{Any,Vector{Pair{DTask,Int}}}() - - return new(data_origin, data_locality, - args_owner, args_readers) - end -end -struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}} - # Whether aliasing is being analyzed - aliasing::Bool - - # The ordered list of tasks and their read/write dependencies - dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}} - - # The mapping of memory space to remote argument copies - remote_args::Dict{MemorySpace,IdDict{Any,Any}} - - # Cache of whether arguments supports in-place move - supports_inplace_cache::IdDict{Any,Bool} - - # The aliasing analysis state - alias_state::State - - function DataDepsState(aliasing::Bool) - dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[] - remote_args = Dict{MemorySpace,IdDict{Any,Any}}() - supports_inplace_cache = IdDict{Any,Bool}() - if aliasing - state = DataDepsAliasingState() - else - state = DataDepsNonAliasingState() - end - return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state) - end -end - -function aliasing(astate::DataDepsAliasingState, arg, dep_mod) - aw = ArgumentWrapper(arg, dep_mod) - get!(astate.ainfo_cache, aw) do - return AliasingWrapper(aliasing(arg, dep_mod)) - end -end - -function supports_inplace_move(state::DataDepsState, arg) - return get!(state.supports_inplace_cache, arg) do - return supports_inplace_move(arg) - end -end - -# Determine which arguments could be written to, and thus need tracking - -"Whether `arg` has any writedep in this datadeps region." 
-function has_writedep(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - # Check if we are writing to this memory - writedep = any(dep->dep[3], deps) - if writedep - arg_has_writedep[arg] = true - return true - end - - # Check if another task is writing to this memory - for (_, taskdeps) in state.dependencies - for (_, other_arg_writedep, _, _, other_arg) in taskdeps - other_arg_writedep || continue - if arg === other_arg - return true - end - end - end - - return false -end -""" -Whether `arg` has any writedep at or before executing `task` in this -datadeps region. -""" -function has_writedep(state::DataDepsState, arg, deps, task::DTask) - is_writedep(arg, deps, task) && return true - if state.aliasing - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, other_ainfo, _, _) in other_taskdeps - writedep || continue - for (dep_mod, _, _) in deps - ainfo = aliasing(state.alias_state, arg, dep_mod) - if will_alias(ainfo, other_ainfo) - return true - end - end - end - if task === other_task - return false - end - end - else - for (other_task, other_taskdeps) in state.dependencies - for (readdep, writedep, _, _, other_arg) in other_taskdeps - writedep || continue - if arg === other_arg - return true - end - end - if task === other_task - return false - end - end - end - error("Task isn't in argdeps set") -end -"Whether `arg` is written to by `task`." -function is_writedep(arg, deps, task::DTask) - return any(dep->dep[3], deps) -end - -# Aliasing state setup -function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask) - # Populate task dependencies - dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}() - - # Track the task's arguments and access patterns - for (idx, _arg) in enumerate(spec.fargs) - # Unwrap In/InOut/Out wrappers and record dependencies - arg, deps = unwrap_inout(value(_arg)) - - # Unwrap the Chunk underlying any DTask arguments - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - - # Skip non-aliasing arguments - type_may_alias(typeof(arg)) || continue - - # Add all aliasing dependencies - for (dep_mod, readdep, writedep) in deps - if state.aliasing - ainfo = aliasing(state.alias_state, arg, dep_mod) - else - ainfo = AliasingWrapper(UnknownAliasing()) - end - push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg)) - end - - # Populate argument write info - populate_argument_info!(state, arg, deps) - end - - # Track the task result too - # N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this - push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task)) - - # Record argument/result dependencies - push!(state.dependencies, task => dependencies_to_add) -end -function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps) - astate = state.alias_state - for (dep_mod, readdep, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - - # Initialize owner and readers - if !haskey(astate.ainfos_owner, ainfo) - overlaps = Set{AliasingWrapper}() - push!(overlaps, ainfo) - for other_ainfo in keys(astate.ainfos_owner) - ainfo == other_ainfo && continue - if will_alias(ainfo, other_ainfo) - push!(overlaps, other_ainfo) - push!(astate.ainfos_overlaps[other_ainfo], ainfo) - end - end - astate.ainfos_overlaps[ainfo] = overlaps - astate.ainfos_owner[ainfo] = nothing - astate.ainfos_readers[ainfo] = Pair{DTask,Int}[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, ainfo) - astate.data_locality[ainfo] = memory_space(arg) - astate.data_origin[ainfo] = memory_space(arg) - end - end -end -function populate_argument_info!(state::DataDepsState{DataDepsNonAliasingState}, arg, deps) - astate = state.alias_state - # Initialize owner and readers - if !haskey(astate.args_owner, arg) - astate.args_owner[arg] = nothing - 
astate.args_readers[arg] = DTask[] - end - - # Assign data owner and locality - if !haskey(astate.data_locality, arg) - astate.data_locality[arg] = memory_space(arg) - astate.data_origin[arg] = memory_space(arg) - end -end -function populate_return_info!(state::DataDepsState{DataDepsAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - # FIXME: We don't yet know about ainfos for this task -end -function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, task, space) - astate = state.alias_state - @assert !haskey(astate.data_locality, task) - astate.data_locality[task] = space - astate.data_origin[task] = space -end - -""" - supports_inplace_move(x) -> Bool - -Returns `false` if `x` doesn't support being copied into from another object -like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting -to copy between values which don't support mutation or otherwise don't have an -implemented `move!` and want to skip in-place copies. When this returns -`false`, datadeps will instead perform out-of-place copies for each non-local -use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` -region returns. 
-""" -supports_inplace_move(x) = true -supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) -function supports_inplace_move(c::Chunk) - # FIXME: Use MemPool.access_ref - pid = root_worker_id(c.processor) - if pid == myid() - return supports_inplace_move(poolget(c.handle)) - else - return remotecall_fetch(supports_inplace_move, pid, c) - end -end -supports_inplace_move(::Function) = false - -# Read/write dependency management -function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) - _get_read_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end -function get_read_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps) - _get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps) -end - -function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - other_task_write_num = astate.ainfos_owner[other_ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo" - other_task_write_num === nothing && continue - other_task, other_write_num = other_task_write_num - write_num == other_write_num && continue - @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end -end -function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) - astate = state.alias_state - ainfo.inner isa NoAliasing && return - for other_ainfo in astate.ainfos_overlaps[ainfo] - @dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo" - other_tasks = astate.ainfos_readers[other_ainfo] - for (other_task, other_write_num) in other_tasks - write_num == other_write_num && continue - 
@dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - state.alias_state.ainfos_owner[ainfo] = task=>write_num - empty!(state.alias_state.ainfos_readers[ainfo]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, ainfo, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num) - push!(state.alias_state.ainfos_readers[ainfo], task=>write_num) -end - -function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - other_task_write_num = state.alias_state.args_owner[arg] - if other_task_write_num !== nothing - other_task, other_write_num = other_task_write_num - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) - for (other_task, other_write_num) in state.alias_state.args_readers[arg] - if write_num != other_write_num - push!(syncdeps, ThunkSyncdep(other_task)) - end - end -end -function add_writer!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - state.alias_state.args_owner[arg] = task=>write_num - empty!(state.alias_state.args_readers[arg]) - # Not necessary to assert a read, but conceptually it's true - add_reader!(state, arg, task, write_num) -end -function add_reader!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num) - push!(state.alias_state.args_readers[arg], task=>write_num) -end - -# Make a copy of each piece of data on each worker -# memory_space => {arg => copy_of_arg} -isremotehandle(x) = false -isremotehandle(x::DTask) = true -isremotehandle(x::Chunk) = true -function generate_slot!(state::DataDepsState, dest_space, data) - if 
data isa DTask - data = fetch(data; raw=true) - end - orig_space = memory_space(data) - to_proc = first(processors(dest_space)) - from_proc = first(processors(orig_space)) - dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) - if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) - # Fast path for local data or data already in a Chunk - data_chunk = tochunk(data, from_proc) - dest_space_args[data] = data_chunk - @assert processor(data_chunk) in processors(dest_space) || data isa Chunk && processor(data) isa Dagger.OSProc - @assert memory_space(data_chunk) == orig_space - else - to_w = root_worker_id(dest_space) - ctx = Sch.eager_context() - id = rand(Int) - dest_space_args[data] = remotecall_fetch(to_w, from_proc, to_proc, data) do from_proc, to_proc, data - timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) - data_converted = move(from_proc, to_proc, data) - data_chunk = tochunk(data_converted, to_proc) - @assert processor(data_chunk) in processors(dest_space) - @assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - if orig_space != dest_space - @assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. 
$(typeof(data_chunk))), spaces ($orig_space -> $dest_space)" - end - return data_chunk - end - timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=dest_space_args[data])) - end - return dest_space_args[data] -end - -struct DataDepsSchedulerState - task_to_spec::Dict{DTask,DTaskSpec} - assignments::Dict{DTask,MemorySpace} - dependencies::Dict{DTask,Set{DTask}} - task_completions::Dict{DTask,UInt64} - space_completions::Dict{MemorySpace,UInt64} - capacities::Dict{MemorySpace,Int} - - function DataDepsSchedulerState() - return new(Dict{DTask,DTaskSpec}(), - Dict{DTask,MemorySpace}(), - Dict{DTask,Set{DTask}}(), - Dict{DTask,UInt64}(), - Dict{MemorySpace,UInt64}(), - Dict{MemorySpace,Int}()) - end -end - -function distribute_tasks!(queue::DataDepsTaskQueue) - #= TODO: Improvements to be made: - # - Support for copying non-AbstractArray arguments - # - Parallelize read copies - # - Unreference unused slots - # - Reuse memory when possible - # - Account for differently-sized data - =# - - # Get the set of all processors to be scheduled on - all_procs = Processor[] - scope = get_compute_scope() - for w in procs() - append!(all_procs, get_processors(OSProc(w))) - end - filter!(proc->!isa(constrain(ExactScope(proc), scope), - InvalidScope), - all_procs) - if isempty(all_procs) - throw(Sch.SchedulingException("No processors available, try widening scope")) - end - exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...)) - if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) - @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 - end - - # Round-robin assign tasks to processors - upper_queue = get_options(:task_queue) - - traversal = queue.traversal - if traversal == :inorder - # As-is - task_order = Colon() - elseif traversal == :bfs - # BFS - 
task_order = Int[1] - to_walk = Int[1] - seen = Set{Int}([1]) - while !isempty(to_walk) - # N.B. next_root has already been seen - next_root = popfirst!(to_walk) - for v in outneighbors(queue.g, next_root) - if !(v in seen) - push!(task_order, v) - push!(seen, v) - push!(to_walk, v) - end - end - end - elseif traversal == :dfs - # DFS (modified with backtracking) - task_order = Int[] - to_walk = Int[1] - seen = Set{Int}() - while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) - next_root = popfirst!(to_walk) - if !(next_root in seen) - iv = inneighbors(queue.g, next_root) - if all(v->v in seen, iv) - push!(task_order, next_root) - push!(seen, next_root) - ov = outneighbors(queue.g, next_root) - prepend!(to_walk, ov) - else - push!(to_walk, next_root) - end - end - end - else - throw(ArgumentError("Invalid traversal mode: $traversal")) - end - - state = DataDepsState(queue.aliasing) - astate = state.alias_state - sstate = DataDepsSchedulerState() - for proc in all_procs - space = only(memory_spaces(proc)) - get!(()->0, sstate.capacities, space) - sstate.capacities[space] += 1 - end - - # Start launching tasks and necessary copies - write_num = 1 - proc_idx = 1 - pressures = Dict{Processor,Int}() - proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) - for (spec, task) in queue.seen_tasks[task_order] - # Populate all task dependencies - populate_task_info!(state, spec, task) - - task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) - scheduler = queue.scheduler - if scheduler == :naive - raw_args = map(arg->tochunk(value(arg)), spec.fargs) - our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - # Calculate costs per processor and select the most optimal - # FIXME: This should consider any already-allocated slots, - # whether they are up-to-date, and if not, the cost of moving - # data to them - 
procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) - return first(procs) - end - end - elseif scheduler == :smart - raw_args = map(filter(arg->haskey(astate.data_locality, value(arg)), spec.fargs)) do arg - arg_chunk = tochunk(last(arg)) - # Only the owned slot is valid - # FIXME: Track up-to-date copies and pass all of those - return arg_chunk => data_locality[arg] - end - f_chunk = tochunk(value(f)) - our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - - @lock sch_state.lock begin - tx_rate = sch_state.transfer_rate[] - - costs = Dict{Processor,Float64}() - for proc in all_procs - # Filter out chunks that are already local - chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) - - # Estimate network transfer costs based on data size - # N.B. `affinity(x)` really means "data size of `x`" - # N.B. 
We treat same-worker transfers as having zero transfer cost - tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) - - # Estimate total cost to move data and get task running after currently-scheduled tasks - est_time_util = get(pressures, proc, UInt64(0)) - costs[proc] = est_time_util + (tx_cost/tx_rate) - end - - # Look up estimated task cost - sig = Sch.signature(sch_state, f, map(first, chunks_locality)) - task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) - - # Shuffle procs around, so equally-costly procs are equally considered - P = randperm(length(all_procs)) - procs = getindex.(Ref(all_procs), P) - - # Sort by lowest cost first - sort!(procs, by=p->costs[p]) - - best_proc = first(procs) - return best_proc, task_pressure - end - end - # FIXME: Pressure should be decreased by pressure of syncdeps on same processor - pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure - elseif scheduler == :ultra - args = Base.mapany(spec.fargs) do arg - pos, data = arg - data, _ = unwrap_inout(data) - if data isa DTask - data = fetch(data; raw=true) - end - return pos => tochunk(data) - end - f_chunk = tochunk(value(f)) - task_time = remotecall_fetch(1, f_chunk, args) do f, args - Sch.init_eager() - sch_state = Sch.EAGER_STATE[] - return @lock sch_state.lock begin - sig = Sch.signature(sch_state, f, args) - return get(sch_state.signature_time_cost, sig, 1000^3) - end - end - - # FIXME: Copy deps are computed eagerly - deps = get(Set{Any}, spec.options, :syncdeps) - - # Find latest time-to-completion of all syncdeps - deps_completed = UInt64(0) - for dep in deps - haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded - deps_completed = max(deps_completed, sstate.task_completions[dep]) - end - - # Find latest time-to-completion of each memory space - # FIXME: Figure out space completions based on optimal packing - spaces_completed = Dict{MemorySpace,UInt64}() - for space in exec_spaces - completed = 
UInt64(0) - for (task, other_space) in sstate.assignments - space == other_space || continue - completed = max(completed, sstate.task_completions[task]) - end - spaces_completed[space] = completed - end - - # Choose the earliest-available memory space and processor - # FIXME: Consider move time - move_time = UInt64(0) - local our_space_completed - while true - our_space_completed, our_space = findmin(spaces_completed) - our_space_procs = filter(proc->proc in all_procs, processors(our_space)) - if isempty(our_space_procs) - delete!(spaces_completed, our_space) - continue - end - our_proc = rand(our_space_procs) - break - end - - sstate.task_to_spec[task] = spec - sstate.assignments[task] = our_space - sstate.task_completions[task] = our_space_completed + move_time + task_time - elseif scheduler == :roundrobin - our_proc = all_procs[proc_idx] - if task_scope == scope - # all_procs is already limited to scope - else - if isa(constrain(task_scope, scope), InvalidScope) - throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) - end - while !proc_in_scope(our_proc, task_scope) - proc_idx = mod1(proc_idx + 1, length(all_procs)) - our_proc = all_procs[proc_idx] - end - end - else - error("Invalid scheduler: $sched") - end - @assert our_proc in all_procs - our_space = only(memory_spaces(our_proc)) - - # Find the scope for this task (and its copies) - if task_scope == scope - # Optimize for the common case, cache the proc=>scope mapping - our_scope = get!(proc_to_scope_lfu, our_proc) do - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - return constrain(UnionScope(map(ExactScope, our_procs)...), scope) - end - else - # Use the provided scope and constrain it to the available processors - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) - end - if our_scope isa InvalidScope - throw(Sch.SchedulingException("Scopes 
are not compatible: $(our_scope.x), $(our_scope.y)")) - end - - f = spec.fargs[1] - f.value = move(ThreadProc(myid(), 1), our_proc, value(f)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" - - # Copy raw task arguments for analysis - task_args = map(copy, spec.fargs) - - # Copy args from local to remote - for (idx, _arg) in enumerate(task_args) - # Is the data writeable? - arg, deps = unwrap_inout(value(_arg)) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (unwritten)" - spec.fargs[idx].value = arg - continue - end - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (non-writeable)" - spec.fargs[idx].value = arg - continue - end - - # Is the source of truth elsewhere? - arg_remote = get!(get!(IdDict{Any,Any}, state.remote_args, our_space), arg) do - generate_slot!(state, our_space, arg) - end - if queue.aliasing - for (dep_mod, _, _) in deps - ainfo = aliasing(astate, arg, dep_mod) - data_space = astate.data_locality[ainfo] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) - add_writer!(state, ainfo, copy_to, write_num) - - 
astate.data_locality[ainfo] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Skipped copy-to (local): $data_space" - end - end - else - data_space = astate.data_locality[arg] - nonlocal = our_space != data_space - if nonlocal - # Add copy-to operation (depends on latest owner of arg) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Enqueueing copy-to: $data_space => $our_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_space), arg) do - generate_slot!(state, data_space, arg) - end - copy_to_scope = our_scope - copy_to_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) - add_writer!(state, arg, copy_to, write_num) - - astate.data_locality[arg] = our_space - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Skipped copy-to (local): $data_space" - end - end - spec.fargs[idx].value = arg_remote - end - write_num += 1 - - # Validate that we're not accidentally performing a copy - for (idx, _arg) in enumerate(spec.fargs) - _, deps = unwrap_inout(value(task_args[idx])) - # N.B. 
We only do this check when the argument supports in-place - # moves, because for the moment, we are not guaranteeing updates or - # write-back of results - arg = value(_arg) - if is_writedep(arg, deps, task) && supports_inplace_move(state, arg) - arg_space = memory_space(arg) - @assert arg_space == our_space "($(repr(value(f))))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space" - end - end - - # Calculate this task's syncdeps - if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{ThunkSyncdep}() - end - syncdeps = spec.options.syncdeps - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - supports_inplace_move(state, arg) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as writer" - get_write_deps!(state, ainfo, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Syncing as reader" - get_read_deps!(state, ainfo, task, write_num, syncdeps) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as writer" - get_write_deps!(state, arg, task, write_num, syncdeps) - else - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Syncing as reader" - get_read_deps!(state, arg, task, write_num, syncdeps) - end - end - end - @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) $(length(syncdeps)) syncdeps" - - # Launch user's task - spec.options.scope = our_scope - spec.options.exec_scope = our_scope - enqueue!(upper_queue, spec=>task) - - # Update read/write tracking for arguments - for (idx, (_, arg)) in enumerate(task_args) - arg, deps = unwrap_inout(arg) - arg = arg isa DTask ? 
fetch(arg; raw=true) : arg - type_may_alias(typeof(arg)) || continue - if queue.aliasing - for (dep_mod, _, writedep) in deps - ainfo = aliasing(astate, arg, dep_mod) - if writedep - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] Set as owner" - add_writer!(state, ainfo, task, write_num) - else - add_reader!(state, ainfo, task, write_num) - end - end - else - if is_writedep(arg, deps, task) - @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] Set as owner" - add_writer!(state, arg, task, write_num) - else - add_reader!(state, arg, task, write_num) - end - end - end - - # Update tracking for return value - populate_return_info!(state, task, our_space) - - write_num += 1 - proc_idx = mod1(proc_idx + 1, length(all_procs)) - end - - # Copy args from remote to local - if queue.aliasing - # We need to replay the writes from all tasks in-order (skipping any - # outdated write owners), to ensure that overlapping writes are applied - # in the correct order - - # First, find the latest owners of each live ainfo - arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}() - for (task, taskdeps) in state.dependencies - for (_, writedep, ainfo, dep_mod, arg) in taskdeps - writedep || continue - haskey(astate.data_locality, ainfo) || continue - @assert haskey(astate.ainfos_owner, ainfo) "Missing ainfo: $ainfo ($dep_mod($(typeof(arg))))" - - # Skip virtual writes from task result aliasing - # FIXME: Make this less bad - if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing - continue - end - - # Skip non-writeable arguments - if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - continue - end - - # Get the set of writers - ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg) - - #= FIXME: If we fully overlap any writer, evict them - idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes) - 
deleteat!(ainfo_writes, idxs) - =# - - # Make ourselves the latest writer - push!(ainfo_writes, (ainfo, dep_mod, astate.data_locality[ainfo])) - end - end - - # Then, replay the writes from each owner in-order - # FIXME: write_num should advance across overlapping ainfo's, as - # writes must be ordered sequentially - for (arg, ainfo_writes) in arg_writes - if length(ainfo_writes) > 1 - # FIXME: Remove me - deleteat!(ainfo_writes, 1:length(ainfo_writes)-1) - end - for (ainfo, dep_mod, data_remote_space) in ainfo_writes - # Is the source of truth elsewhere? - data_local_space = astate.data_origin[ainfo] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "[$dep_mod] Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = get!(get!(IdDict{Any,Any}, state.remote_args, data_local_space), arg) do - generate_slot!(state, data_local_space, arg) - end - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" - end - end - end - else - for arg in keys(astate.data_origin) - # Is the data previously written? - arg, deps = unwrap_inout(arg) - if !type_may_alias(typeof(arg)) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)" - end - - # Can the data be written back to? 
- if !supports_inplace_move(state, arg) - @dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)" - end - - # Is the source of truth elsewhere? - data_remote_space = astate.data_locality[arg] - data_local_space = astate.data_origin[arg] - if data_local_space != data_remote_space - # Add copy-from operation - @dagdebug nothing :spawn_datadeps "Enqueueing copy-from: $data_remote_space => $data_local_space" - arg_local = state.remote_args[data_local_space][arg] - arg_remote = state.remote_args[data_remote_space][arg] - @assert arg_remote !== arg_local - data_local_proc = first(processors(data_local_space)) - copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set{ThunkSyncdep}() - get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) - @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) - else - @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" - end - end - end -end - -""" - spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) - -Constructs a "datadeps" (data dependencies) region and calls `f` within it. -Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or -`InOut` to indicate whether the task will read, write, or read+write that -argument, respectively. 
These argument dependencies will be used to specify -which tasks depend on each other based on the following rules: - -- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other -- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects -- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel -- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies -- An `In` dependency synchronizes with any previous `Out` dependencies -- If unspecified, an `In` dependency is assumed - -In general, the result of executing tasks following the above rules will be -equivalent to simply executing tasks sequentially and in order of submission. -Of course, if dependencies are incorrectly specified, undefined behavior (and -unexpected results) may occur. - -Unlike other Dagger tasks, tasks executed within a datadeps region are allowed -to write to their arguments when annotated with `Out` or `InOut` -appropriately. - -At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks -to complete, rethrowing the first error, if any. The result of `f` will be -returned from `spawn_datadeps`. - -The keyword argument `traversal` controls the order that tasks are launched by -the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling -or Depth-First Scheduling, respectively. All traversal orders respect the -dependencies and ordering of the launched tasks, but may provide better or -worse performance for a given set of datadeps tasks. This argument is -experimental and subject to change. 
-""" -function spawn_datadeps(f::Base.Callable; static::Bool=true, - traversal::Symbol=:inorder, - scheduler::Union{Symbol,Nothing}=nothing, - aliasing::Bool=true, - launch_wait::Union{Bool,Nothing}=nothing) - if !static - throw(ArgumentError("Dynamic scheduling is no longer available")) - end - wait_all(; check_errors=true) do - scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol - launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool - if launch_wait - result = spawn_bulk() do - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - else - queue = DataDepsTaskQueue(get_options(:task_queue); - traversal, scheduler, aliasing) - result = with_options(f; task_queue=queue) - distribute_tasks!(queue) - end - return result - end -end -const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) -const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) diff --git a/src/datadeps/aliasing.jl b/src/datadeps/aliasing.jl new file mode 100644 index 00000000..1958fe52 --- /dev/null +++ b/src/datadeps/aliasing.jl @@ -0,0 +1,753 @@ +import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv + +export In, Out, InOut, Deps, spawn_datadeps + +#= +============================================================================== + DATADEPS ALIASING AND DATA MOVEMENT SYSTEM +============================================================================== + +This file implements the data dependencies system for Dagger tasks, which allows +tasks to write to their arguments in a controlled manner. The system maintains +data coherency across distributed workers by tracking aliasing relationships +and orchestrating data movement operations. 
+ +OVERVIEW: +--------- +The datadeps system enables parallel execution of tasks that modify shared data +by analyzing memory aliasing relationships and scheduling appropriate data +transfers. The core challenge is maintaining coherency when aliased data (e.g., +an array and its views) needs to be accessed by tasks running on different workers. + +KEY CONCEPTS: +------------- + +1. ALIASING ANALYSIS: + - Every mutable argument is analyzed for its memory access pattern + - Memory spans are computed to determine which bytes in memory are accessed + - Objects that access overlapping memory spans are considered "aliasing" + - Examples: An array A and view(A, 2:3, 2:3) alias each other + +2. DATA LOCALITY TRACKING: + - The system tracks where the "source of truth" for each piece of data lives + - As tasks execute and modify data, the source of truth may move between workers + - Each aliasing region can have its own independent source of truth location + +3. ALIASED OBJECT MANAGEMENT: + - When copying arguments between workers, the system tracks "aliased objects" + - This ensures that if both an array and its view need to be copied to a worker, + only one copy of the underlying array is made, with the view pointing to it + - The aliased_object!() functions manage this sharing + +THE DISTRIBUTED ALIASING PROBLEM: +--------------------------------- + +In a multithreaded environment, aliasing "just works" because all tasks operate +on the same memory. However, in a distributed environment, arguments must be +copied between workers, which breaks aliasing relationships. 
+ +Consider this scenario: +```julia +A = rand(4, 4) +vA = view(A, 2:3, 2:3) + +Dagger.spawn_datadeps() do + Dagger.@spawn inc!(InOut(A), 1) # Task 1: increment all of A + Dagger.@spawn inc!(InOut(vA), 2) # Task 2: increment view of A +end +``` + +MULTITHREADED BEHAVIOR (WORKS): +- Both tasks run on the same worker +- They operate on the same memory, with proper dependency tracking +- Task dependencies ensure correct ordering (e.g., Task 1 then Task 2) + +DISTRIBUTED BEHAVIOR (THE PROBLEM): +- Tasks may be scheduled on different workers +- Each argument must be copied to the destination worker +- Without special handling, we would copy A to worker1 and vA to worker2 +- This creates two separate arrays, breaking the aliasing relationship +- Updates to the view on worker2 don't affect the array on worker1 + +THE SOLUTION - PARTIAL DATA MOVEMENT: +------------------------------------- + +The datadeps system solves this by: + +1. UNIFIED ALLOCATION: + - When copying aliased objects, ensure only one underlying array exists per worker + - Use aliased_object!() to detect and reuse existing allocations + - Views on the destination worker point to the shared underlying array + +2. PARTIAL DATA TRANSFER: + - Instead of copying entire objects, only transfer the "dirty" regions + - This minimizes network traffic and maximizes parallelism + - Uses the move!(dep_mod, ...) function with dependency modifiers + +3. REMAINDER TRACKING: + - When a partial region is updated, track what parts still need updating + - Before a task needs the full object, copy the remaining "clean" regions + - This preserves all updates while avoiding overwrites + +EXAMPLE EXECUTION FLOW: +----------------------- + +Given: A = 4x4 array, vA = view(A, 2:3, 2:3) +Tasks: T1 modifies InOut(A), T2 modifies InOut(vA) + +1. INITIAL STATE: + - A and vA both exist on worker0 (main worker) + - A's data_locality = worker0, vA's data_locality = worker0 + +2. 
T1 SCHEDULED ON WORKER1: + - Copy A from worker0 to worker1 + - T1 executes, modifying all of A on worker1 + - Update: A's data_locality = worker1, A is now "dirty" on worker1 + +3. T2 SCHEDULED ON WORKER2: + - T2 needs vA, but vA aliases with A (which was modified by T1) + - Copy vA-region of A from worker1 to worker2 + - This is a PARTIAL copy - only the 2:3, 2:3 region + - Create vA on worker2 pointing to the appropriate region + - T2 executes, modifying vA region on worker2 + - Update: vA's data_locality = worker2 + +4. FINAL SYNCHRONIZATION: + - Some future task needs the complete A + - A needs to be assembled from: worker1 (non-vA regions) + worker2 (vA region) + - REMAINDER COPY: Copy non-vA regions from worker1 to worker2 + - OR INVERSE: Copy vA-region from worker2 to worker1, then copy full A + +MEMORY SPAN COMPUTATION: +------------------------ + +The system uses memory spans to determine aliasing and compute remainders: + +- ContiguousAliasing: Single contiguous memory region (e.g., full array) +- StridedAliasing: Multiple non-contiguous regions (e.g., SubArray) +- DiagonalAliasing: Diagonal elements only (e.g., Diagonal(A)) +- TriangularAliasing: Triangular regions (e.g., UpperTriangular(A)) + +Remainder computation involves: +1. Computing memory spans for all overlapping aliasing objects +2. Finding the set difference: full_object_spans - updated_spans +3. Creating a "remainder aliasing" object representing the not-yet-updated regions +4. Performing move! with this remainder object to copy only needed data + +DATA MOVEMENT FUNCTIONS: +------------------------ + +move!(dep_mod, to_space, from_space, to, from): +- The core in-place data movement function +- dep_mod specifies which part of the data to copy (identity, UpperTriangular, etc.) 
+- Supports partial copies via dependency modifiers + +move_rewrap(): +- Handles copying of wrapped objects (SubArrays, ChunkViews) +- Ensures aliased objects are reused on destination worker + +enqueue_copy_to!(): +- Schedules data movement tasks before user tasks +- Ensures data is up-to-date on the worker where a task will run + +CURRENT LIMITATIONS AND TODOS: +------------------------------- + +1. REMAINDER COMPUTATION: + - The system currently handles simple overlaps but needs sophisticated + remainder calculation for complex aliasing patterns + - Need functions to compute span set differences + +2. ORDERING DEPENDENCIES: + - Need to ensure remainder copies happen in correct order + - Must not overwrite more recent updates with stale data + +3. COMPLEX ALIASING PATTERNS: + - Multiple overlapping views of the same array + - Nested aliasing structures (views of views) + - Mixed aliasing types (diagonal + triangular regions) + +4. PERFORMANCE OPTIMIZATION: + - Minimize number of copy operations + - Batch compatible transfers + - Optimize for common access patterns +=# + +"Specifies a read-only dependency." +struct In{T} + x::T +end +"Specifies a write-only dependency." +struct Out{T} + x::T +end +"Specifies a read-write dependency." +struct InOut{T} + x::T +end +"Specifies one or more dependencies." +struct Deps{T,DT<:Tuple} + x::T + deps::DT +end +Deps(x, deps...) 
= Deps(x, deps) + +chunktype(::In{T}) where T = T +chunktype(::Out{T}) where T = T +chunktype(::InOut{T}) where T = T +chunktype(::Deps{T,DT}) where {T,DT} = T + +function unwrap_inout(arg) + readdep = false + writedep = false + if arg isa In + readdep = true + arg = arg.x + elseif arg isa Out + writedep = true + arg = arg.x + elseif arg isa InOut + readdep = true + writedep = true + arg = arg.x + elseif arg isa Deps + alldeps = Tuple[] + for dep in arg.deps + dep_mod, inner_deps = unwrap_inout(dep) + for (_, readdep, writedep) in inner_deps + push!(alldeps, (dep_mod, readdep, writedep)) + end + end + arg = arg.x + return arg, alldeps + else + readdep = true + end + return arg, Tuple[(identity, readdep, writedep)] +end + +_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h) +_identity_hash(arg::Chunk, h::UInt=UInt(0)) = hash(arg.handle, hash(Chunk, h)) +_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h)))) +_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h)) + +struct ArgumentWrapper + arg + dep_mod + hash::UInt + + function ArgumentWrapper(arg, dep_mod) + h = hash(dep_mod) + h = _identity_hash(arg, h) + return new(arg, dep_mod, h) + end +end +Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash) +Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash +Base.isequal(aw1::ArgumentWrapper, aw2::ArgumentWrapper) = + aw1.hash == aw2.hash + +struct HistoryEntry + ainfo::AliasingWrapper + space::MemorySpace + write_num::Int +end + +struct DataDepsState + # The mapping of original raw argument to its Chunk + raw_arg_to_chunk::IdDict{Any,Chunk} + + # The origin memory space of each argument + # Used to track the original location of an argument, for final copy-from + arg_origin::IdDict{Any,MemorySpace} + + # The mapping of memory space to argument to remote argument copies + # Used 
to replace an argument with its remote copy + remote_args::Dict{MemorySpace,IdDict{Any,Chunk}} + + # The mapping of remote argument to original argument + remote_arg_to_original::IdDict{Any,Any} + + # The mapping of original argument wrapper to remote argument wrapper + remote_arg_w::Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}} + + # The mapping of ainfo to argument and dep_mod + # Used to lookup which argument and dep_mod a given ainfo is generated from + # N.B. This is a mapping for remote argument copies + ainfo_arg::Dict{AliasingWrapper,ArgumentWrapper} + + # The history of writes (direct or indirect) to each argument and dep_mod, in terms of ainfos directly written to, and the memory space they were written to + # Updated when a new write happens on an overlapping ainfo + # Used by remainder copies to track which portions of an argument and dep_mod were written to elsewhere, through another argument + arg_history::Dict{ArgumentWrapper,Vector{HistoryEntry}} + + # The mapping of memory space and argument to the memory space of the last direct write + # Used by remainder copies to lookup the "backstop" if any portion of the target ainfo is not updated by the remainder + arg_owner::Dict{ArgumentWrapper,MemorySpace} + + # The overlap of each argument with every other argument, based on the ainfo overlaps + # Incrementally updated as new ainfos are created + # Used for fast history updates + arg_overlaps::Dict{ArgumentWrapper,Set{ArgumentWrapper}} + + # The mapping of, for a given memory space, the backing Chunks that an ainfo references + # Used by slot generation to replace the backing Chunks during move + ainfo_backing_chunk::Dict{MemorySpace,Dict{AbstractAliasing,Chunk}} + + # Cache of argument's supports_inplace_move query result + supports_inplace_cache::IdDict{Any,Bool} + + # Cache of argument and dep_mod to ainfo + # N.B. 
This is a mapping for remote argument copies + ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper} + + # The oracle for aliasing lookups + # Used to populate ainfos_overlaps efficiently + ainfos_lookup::AliasingLookup + + # The overlapping ainfos for each ainfo + # Incrementally updated as new ainfos are created + # Used for fast will_alias lookups + ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}} + + # Track writers ("owners") and readers + # Updated as new writer and reader tasks are launched + # Used by task dependency tracking to calculate syncdeps and ensure correct launch ordering + ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}} + ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}} + + function DataDepsState(aliasing::Bool) + if !aliasing + @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1 + end + + arg_to_chunk = IdDict{Any,Chunk}() + arg_origin = IdDict{Any,MemorySpace}() + remote_args = Dict{MemorySpace,IdDict{Any,Any}}() + remote_arg_to_original = IdDict{Any,Any}() + remote_arg_w = Dict{ArgumentWrapper,Dict{MemorySpace,ArgumentWrapper}}() + ainfo_arg = Dict{AliasingWrapper,ArgumentWrapper}() + arg_owner = Dict{ArgumentWrapper,MemorySpace}() + arg_overlaps = Dict{ArgumentWrapper,Set{ArgumentWrapper}}() + ainfo_backing_chunk = Dict{MemorySpace,Dict{AbstractAliasing,Chunk}}() + arg_history = Dict{ArgumentWrapper,Vector{HistoryEntry}}() + + supports_inplace_cache = IdDict{Any,Bool}() + ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}() + + ainfos_lookup = AliasingLookup() + ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}() + + ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}() + ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}() + + return new(arg_to_chunk, arg_origin, remote_args, remote_arg_to_original, remote_arg_w, ainfo_arg, arg_owner, arg_overlaps, ainfo_backing_chunk, arg_history, + supports_inplace_cache, ainfo_cache, 
ainfos_lookup, ainfos_overlaps, ainfos_owner, ainfos_readers) + end +end + +function supports_inplace_move(state::DataDepsState, arg) + return get!(state.supports_inplace_cache, arg) do + return supports_inplace_move(arg) + end +end + +# Determine which arguments could be written to, and thus need tracking +"Whether `arg` is written to by `task`." +function is_writedep(arg, deps, task::DTask) + return any(dep->dep[3], deps) +end + +# Aliasing state setup +function populate_task_info!(state::DataDepsState, task_args, spec::DTaskSpec, task::DTask) + # Track the task's arguments and access patterns + return map_or_ntuple(task_args) do idx + _arg = task_args[idx] + + # Unwrap the argument + _arg_with_deps = value(_arg) + pos = _arg.pos + + # Unwrap In/InOut/Out wrappers and record dependencies + arg_pre_unwrap, deps = unwrap_inout(_arg_with_deps) + + # Unwrap the Chunk underlying any DTask arguments + arg = arg_pre_unwrap isa DTask ? fetch(arg_pre_unwrap; raw=true) : arg_pre_unwrap + + # Skip non-aliasing arguments or arguments that don't support in-place move + may_alias = type_may_alias(typeof(arg)) + inplace_move = may_alias && supports_inplace_move(state, arg) + if !may_alias || !inplace_move + arg_w = ArgumentWrapper(arg, identity) + if is_typed(spec) + return TypedDataDepsTaskArgument(arg, pos, may_alias, inplace_move, (DataDepsTaskDependency(arg_w, false, false),)) + else + return DataDepsTaskArgument(arg, pos, may_alias, inplace_move, [DataDepsTaskDependency(arg_w, false, false)]) + end + end + + # Generate a Chunk for the argument if necessary + if haskey(state.raw_arg_to_chunk, arg) + arg_chunk = state.raw_arg_to_chunk[arg] + else + if !(arg isa Chunk) + arg_chunk = tochunk(arg) + state.raw_arg_to_chunk[arg] = arg_chunk + else + state.raw_arg_to_chunk[arg] = arg + arg_chunk = arg + end + end + + # Track the origin space of the argument + origin_space = memory_space(arg_chunk) + state.arg_origin[arg_chunk] = origin_space + 
state.remote_arg_to_original[arg_chunk] = arg_chunk + + # Populate argument info for all aliasing dependencies + # And return the argument, dependencies, and ArgumentWrappers + if is_typed(spec) + deps = Tuple(DataDepsTaskDependency(arg_chunk, dep) for dep in deps) + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return TypedDataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + else + deps = [DataDepsTaskDependency(arg_chunk, dep) for dep in deps] + map_or_ntuple(deps) do dep_idx + dep = deps[dep_idx] + # Populate argument info + populate_argument_info!(state, dep.arg_w, origin_space) + end + return DataDepsTaskArgument(arg_chunk, pos, may_alias, inplace_move, deps) + end + end +end +function populate_argument_info!(state::DataDepsState, arg_w::ArgumentWrapper, origin_space::MemorySpace) + # Initialize ownership and history + if !haskey(state.arg_owner, arg_w) + # N.B. This is valid (even if the backing data is up-to-date elsewhere), + # because we only use this to track the "backstop" if any portion of the + # target ainfo is not updated by the remainder (at which point, this + # is thus the correct owner). + state.arg_owner[arg_w] = origin_space + + # Initialize the overlap set + state.arg_overlaps[arg_w] = Set{ArgumentWrapper}() + end + if !haskey(state.arg_history, arg_w) + state.arg_history[arg_w] = Vector{HistoryEntry}() + end + + # Calculate the ainfo (which will populate ainfo structures and merge history) + aliasing!(state, origin_space, arg_w) +end +# N.B. 
arg_w must be the original argument wrapper, not a remote copy +function aliasing!(state::DataDepsState, target_space::MemorySpace, arg_w::ArgumentWrapper) + if haskey(state.remote_arg_w, arg_w) && haskey(state.remote_arg_w[arg_w], target_space) + remote_arg_w = @inbounds state.remote_arg_w[arg_w][target_space] + remote_arg = remote_arg_w.arg + else + # Grab the remote copy of the argument, and calculate the ainfo + remote_arg = get_or_generate_slot!(state, target_space, arg_w.arg) + remote_arg_w = ArgumentWrapper(remote_arg, arg_w.dep_mod) + get!(Dict{MemorySpace,ArgumentWrapper}, state.remote_arg_w, arg_w)[target_space] = remote_arg_w + end + + # Check if we already have the result cached + if haskey(state.ainfo_cache, remote_arg_w) + return state.ainfo_cache[remote_arg_w] + end + + # Calculate the ainfo + ainfo = AliasingWrapper(aliasing(remote_arg, arg_w.dep_mod)) + + # Cache the result + state.ainfo_cache[remote_arg_w] = ainfo + + # Update the mapping of ainfo to argument and dep_mod + if !haskey(state.ainfo_arg, ainfo) + state.ainfo_arg[ainfo] = remote_arg_w + else + @assert state.ainfo_arg[ainfo] == remote_arg_w + end + + # Populate info for the new ainfo + populate_ainfo!(state, arg_w, ainfo, target_space) + + return ainfo +end +function populate_ainfo!(state::DataDepsState, original_arg_w::ArgumentWrapper, target_ainfo::AliasingWrapper, target_space::MemorySpace) + if !haskey(state.ainfos_owner, target_ainfo) + # Add ourselves to the lookup oracle + ainfo_idx = push!(state.ainfos_lookup, target_ainfo) + + # Find overlapping ainfos + overlaps = Set{AliasingWrapper}() + push!(overlaps, target_ainfo) + for other_ainfo in intersect(state.ainfos_lookup, target_ainfo; ainfo_idx) + target_ainfo == other_ainfo && continue + # Mark us and them as overlapping + push!(overlaps, other_ainfo) + push!(state.ainfos_overlaps[other_ainfo], target_ainfo) + + # Add overlapping history to our own + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg = 
state.remote_arg_to_original[other_remote_arg_w.arg] + other_arg_w = ArgumentWrapper(other_arg, other_remote_arg_w.dep_mod) + push!(state.arg_overlaps[original_arg_w], other_arg_w) + push!(state.arg_overlaps[other_arg_w], original_arg_w) + merge_history!(state, original_arg_w, other_arg_w) + end + state.ainfos_overlaps[target_ainfo] = overlaps + + # Initialize owner and readers + state.ainfos_owner[target_ainfo] = nothing + state.ainfos_readers[target_ainfo] = Pair{DTask,Int}[] + end +end +function merge_history!(state::DataDepsState, arg_w::ArgumentWrapper, other_arg_w::ArgumentWrapper) + history = state.arg_history[arg_w] + @opcounter :merge_history + @opcounter :merge_history_complexity length(history) + origin_space = state.arg_origin[other_arg_w.arg] + for other_entry in state.arg_history[other_arg_w] + write_num_tuple = HistoryEntry(AliasingWrapper(NoAliasing()), origin_space, other_entry.write_num) + range = searchsorted(history, write_num_tuple; by=x->x.write_num) + if !isempty(range) + # Find and skip duplicates + match = false + for source_idx in range + source_entry = history[source_idx] + if source_entry.ainfo == other_entry.ainfo && + source_entry.space == other_entry.space && + source_entry.write_num == other_entry.write_num + match = true + break + end + end + match && continue + + # Insert at the first position + idx = first(range) + else + # Insert at the last position + idx = length(history) + 1 + end + insert!(history, idx, other_entry) + end +end +function truncate_history!(state::DataDepsState, arg_w::ArgumentWrapper) + # FIXME: Do this continuously if possible + if haskey(state.arg_history, arg_w) && length(state.arg_history[arg_w]) > 100000 + origin_space = state.arg_origin[arg_w.arg] + @opcounter :truncate_history + _, last_idx = compute_remainder_for_arg!(state, origin_space, arg_w, 0; compute_syncdeps=false) + if last_idx > 0 + @opcounter :truncate_history_removed last_idx + deleteat!(state.arg_history[arg_w], 1:last_idx) + end + end +end 
+ +""" + supports_inplace_move(x) -> Bool + +Returns `false` if `x` doesn't support being copied into from another object +like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting +to copy between values which don't support mutation or otherwise don't have an +implemented `move!` and want to skip in-place copies. When this returns +`false`, datadeps will instead perform out-of-place copies for each non-local +use of `x`, and the data in `x` will not be updated when the `spawn_datadeps` +region returns. +""" +supports_inplace_move(x) = true +supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true)) +function supports_inplace_move(c::Chunk) + # FIXME: Use MemPool.access_ref + pid = root_worker_id(c.processor) + if pid == myid() + return supports_inplace_move(poolget(c.handle)) + else + return remotecall_fetch(supports_inplace_move, pid, c) + end +end +supports_inplace_move(::Function) = false + +# Read/write dependency management +function get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We need to sync with both writers and readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) + _get_read_deps!(state, dest_space, ainfo, write_num, syncdeps) +end +function get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + # We only need to sync with writers, not readers + _get_write_deps!(state, dest_space, ainfo, write_num, syncdeps) +end + +function _get_write_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + other_task_write_num = state.ainfos_owner[other_ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with writer via $ainfo -> $other_ainfo" + other_task_write_num === nothing && continue + other_task, other_write_num = other_task_write_num + 
write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with writer via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end +end +function _get_read_deps!(state::DataDepsState, dest_space::MemorySpace, ainfo::AbstractAliasing, write_num, syncdeps) + ainfo.inner isa NoAliasing && return + for other_ainfo in state.ainfos_overlaps[ainfo] + @dagdebug nothing :spawn_datadeps_sync "Considering sync with reader via $ainfo -> $other_ainfo" + other_tasks = state.ainfos_readers[other_ainfo] + for (other_task, other_write_num) in other_tasks + write_num == other_write_num && continue + @dagdebug nothing :spawn_datadeps_sync "Sync with reader via $ainfo -> $other_ainfo" + push!(syncdeps, ThunkSyncdep(other_task)) + end + end +end +function add_writer!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + state.ainfos_owner[ainfo] = task=>write_num + empty!(state.ainfos_readers[ainfo]) + + # Clear the history for this target, since this is a new write event + empty!(state.arg_history[arg_w]) + + # Add our own history + push!(state.arg_history[arg_w], HistoryEntry(ainfo, dest_space, write_num)) + + # Find overlapping arguments and update their history + for other_arg_w in state.arg_overlaps[arg_w] + other_arg_w == arg_w && continue + push!(state.arg_history[other_arg_w], HistoryEntry(ainfo, dest_space, write_num)) + end + + # Record the last place we were fully written to + state.arg_owner[arg_w] = dest_space + + # Not necessary to assert a read, but conceptually it's true + add_reader!(state, arg_w, dest_space, ainfo, task, write_num) +end +function add_reader!(state::DataDepsState, arg_w::ArgumentWrapper, dest_space::MemorySpace, ainfo::AbstractAliasing, task, write_num) + push!(state.ainfos_readers[ainfo], task=>write_num) +end + +# Make a copy of each piece of data on each worker +# memory_space => {arg => copy_of_arg} +isremotehandle(x) = false 
+isremotehandle(x::DTask) = true +isremotehandle(x::Chunk) = true +function generate_slot!(state::DataDepsState, dest_space, data) + if data isa DTask + data = fetch(data; raw=true) + end + # N.B. We do not perform any sync/copy with the current owner of the data, + # because all we want here is to make a copy of some version of the data, + # even if the data is not up to date. + orig_space = memory_space(data) + to_proc = first(processors(dest_space)) + from_proc = first(processors(orig_space)) + dest_space_args = get!(IdDict{Any,Any}, state.remote_args, dest_space) + ALIASED_OBJECT_CACHE[] = get!(Dict{AbstractAliasing,Chunk}, state.ainfo_backing_chunk, dest_space) + if orig_space == dest_space && (data isa Chunk || !isremotehandle(data)) + # Fast path for local data that's already in a Chunk or not a remote handle needing rewrapping + data_chunk = tochunk(data, from_proc) + else + ctx = Sch.eager_context() + id = rand(Int) + @maybelog ctx timespan_start(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data)) + data_chunk = move_rewrap(from_proc, to_proc, orig_space, dest_space, data) + @maybelog ctx timespan_finish(ctx, :move, (;thunk_id=0, id, position=ArgPosition(), processor=to_proc), (;f=nothing, data=data_chunk)) + end + @assert memory_space(data_chunk) == dest_space "space mismatch! $dest_space (dest) != $(memory_space(data_chunk)) (actual) ($(typeof(data)) (data) vs. 
$(typeof(data_chunk)) (chunk)), spaces ($orig_space -> $dest_space)" + dest_space_args[data] = data_chunk + state.remote_arg_to_original[data_chunk] = data + + ALIASED_OBJECT_CACHE[] = nothing + + return dest_space_args[data] +end +function get_or_generate_slot!(state, dest_space, data) + @assert !(data isa ArgumentWrapper) + if !haskey(state.remote_args, dest_space) + state.remote_args[dest_space] = IdDict{Any,Any}() + end + if !haskey(state.remote_args[dest_space], data) + return generate_slot!(state, dest_space, data) + end + return state.remote_args[dest_space][data] +end +function move_rewrap(from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, data) + return aliased_object!(data) do data + return remotecall_endpoint(identity, from_proc, to_proc, from_space, to_space, data) + end +end +function remotecall_endpoint(f, from_proc, to_proc, orig_space, dest_space, data) + to_w = root_worker_id(to_proc) + if to_w == myid() + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) + end + return remotecall_fetch(to_w, from_proc, to_proc, dest_space, data) do from_proc, to_proc, dest_space, data + data_converted = f(move(from_proc, to_proc, data)) + return tochunk(data_converted, to_proc, dest_space) + end +end +const ALIASED_OBJECT_CACHE = TaskLocalValue{Union{Dict{AbstractAliasing,Chunk}, Nothing}}(()->nothing) +@warn "Document these public methods" maxlog=1 +# TODO: Use state to cache aliasing() results +function declare_aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + cache[ainfo] = x +end +function aliased_object!(x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + @assert x isa Chunk "x must be a Chunk\nUse functor form of aliased_object!" 
+ cache[ainfo] = x + y = x + end + return y +end +function aliased_object!(f, x; ainfo=aliasing(x, identity)) + cache = ALIASED_OBJECT_CACHE[] + if haskey(cache, ainfo) + y = cache[ainfo] + else + y = f(x) + @assert y isa Chunk "Didn't get a Chunk from functor" + cache[ainfo] = y + end + return y +end +function aliased_object_unwrap!(x::Chunk) + y = unwrap(x) + ainfo = aliasing(y, identity) + return unwrap(aliased_object!(x; ainfo)) +end + +struct DataDepsSchedulerState + task_to_spec::Dict{DTask,DTaskSpec} + assignments::Dict{DTask,MemorySpace} + dependencies::Dict{DTask,Set{DTask}} + task_completions::Dict{DTask,UInt64} + space_completions::Dict{MemorySpace,UInt64} + capacities::Dict{MemorySpace,Int} + + function DataDepsSchedulerState() + return new(Dict{DTask,DTaskSpec}(), + Dict{DTask,MemorySpace}(), + Dict{DTask,Set{DTask}}(), + Dict{DTask,UInt64}(), + Dict{MemorySpace,UInt64}(), + Dict{MemorySpace,Int}()) + end +end \ No newline at end of file diff --git a/src/datadeps/chunkview.jl b/src/datadeps/chunkview.jl new file mode 100644 index 00000000..04b581c1 --- /dev/null +++ b/src/datadeps/chunkview.jl @@ -0,0 +1,64 @@ +struct ChunkView{N} + chunk::Chunk + slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} +end + +function Base.view(c::Chunk, slices...) 
+ if c.domain isa ArrayDomain + nd, sz = ndims(c.domain), size(c.domain) + nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) + + for (i, s) in enumerate(slices) + if s isa Int + 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s isa AbstractRange + isempty(s) && continue + 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) + elseif s === Colon() + continue + else + throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) + end + end + end + + return ChunkView(c, slices) +end + +Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) + +aliasing(x::ChunkView) = + throw(ConcurrencyViolationError("Cannot query aliasing of a ChunkView directly")) +memory_space(x::ChunkView) = memory_space(x.chunk) +isremotehandle(x::ChunkView) = true + +# This definition is here because it's so similar to ChunkView +function move_rewrap(from_proc::Processor, to_proc::Processor, v::SubArray) + to_w = root_worker_id(to_proc) + p_chunk = aliased_object!(parent(v)) do p + return remotecall_fetch(to_w, from_proc, to_proc, p) do from_proc, to_proc, p + return tochunk(move(from_proc, to_proc, p), to_proc) + end + end + inds = parentindices(v) + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, inds) do from_proc, to_proc, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) 
+ return tochunk(v_new, to_proc) + end +end +function move_rewrap(from_proc::Processor, to_proc::Processor, slice::ChunkView) + to_w = root_worker_id(to_proc) + p_chunk = aliased_object!(slice.chunk) do p_chunk + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk) do from_proc, to_proc, p_chunk + return tochunk(move(from_proc, to_proc, p_chunk), to_proc) + end + end + return remotecall_fetch(to_w, from_proc, to_proc, p_chunk, slice.slices) do from_proc, to_proc, p_chunk, inds + p_new = move(from_proc, to_proc, p_chunk) + v_new = view(p_new, inds...) + return tochunk(v_new, to_proc) + end +end + +Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/datadeps/queue.jl b/src/datadeps/queue.jl new file mode 100644 index 00000000..a0c8fc95 --- /dev/null +++ b/src/datadeps/queue.jl @@ -0,0 +1,549 @@ +struct DataDepsTaskQueue <: AbstractTaskQueue + # The queue above us + upper_queue::AbstractTaskQueue + # The set of tasks that have already been seen + seen_tasks::Union{Vector{DTaskPair},Nothing} + # The data-dependency graph of all tasks + g::Union{SimpleDiGraph{Int},Nothing} + # The mapping from task to graph ID + task_to_id::Union{Dict{DTask,Int},Nothing} + # How to traverse the dependency graph when launching tasks + traversal::Symbol + # Which scheduler to use to assign tasks to processors + scheduler::Symbol + + # Whether aliasing across arguments is possible + # The fields following only apply when aliasing==true + aliasing::Bool + + function DataDepsTaskQueue(upper_queue; + traversal::Symbol=:inorder, + scheduler::Symbol=:naive, + aliasing::Bool=true) + seen_tasks = DTaskPair[] + g = SimpleDiGraph() + task_to_id = Dict{DTask,Int}() + return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler, + aliasing) + end +end + +function enqueue!(queue::DataDepsTaskQueue, pair::DTaskPair) + push!(queue.seen_tasks, pair) +end +function enqueue!(queue::DataDepsTaskQueue, 
pairs::Vector{DTaskPair}) + append!(queue.seen_tasks, pairs) +end + +""" + spawn_datadeps(f::Base.Callable; traversal::Symbol=:inorder) + +Constructs a "datadeps" (data dependencies) region and calls `f` within it. +Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or +`InOut` to indicate whether the task will read, write, or read+write that +argument, respectively. These argument dependencies will be used to specify +which tasks depend on each other based on the following rules: + +- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other +- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects +- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel +- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies +- An `In` dependency synchronizes with any previous `Out` dependencies +- If unspecified, an `In` dependency is assumed + +In general, the result of executing tasks following the above rules will be +equivalent to simply executing tasks sequentially and in order of submission. +Of course, if dependencies are incorrectly specified, undefined behavior (and +unexpected results) may occur. + +Unlike other Dagger tasks, tasks executed within a datadeps region are allowed +to write to their arguments when annotated with `Out` or `InOut` +appropriately. + +At the end of executing `f`, `spawn_datadeps` will wait for all launched tasks +to complete, rethrowing the first error, if any. The result of `f` will be +returned from `spawn_datadeps`. + +The keyword argument `traversal` controls the order that tasks are launched by +the scheduler, and may be set to `:bfs` or `:dfs` for Breadth-First Scheduling +or Depth-First Scheduling, respectively. 
All traversal orders respect the +dependencies and ordering of the launched tasks, but may provide better or +worse performance for a given set of datadeps tasks. This argument is +experimental and subject to change. +""" +function spawn_datadeps(f::Base.Callable; static::Bool=true, + traversal::Symbol=:inorder, + scheduler::Union{Symbol,Nothing}=nothing, + aliasing::Bool=true, + launch_wait::Union{Bool,Nothing}=nothing) + if !static + throw(ArgumentError("Dynamic scheduling is no longer available")) + end + wait_all(; check_errors=true) do + scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol + launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool + if launch_wait + result = spawn_bulk() do + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + else + queue = DataDepsTaskQueue(get_options(:task_queue); + traversal, scheduler, aliasing) + result = with_options(f; task_queue=queue) + distribute_tasks!(queue) + end + return result + end +end +const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing) +const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing) + +function distribute_tasks!(queue::DataDepsTaskQueue) + #= TODO: Improvements to be made: + # - Support for copying non-AbstractArray arguments + # - Parallelize read copies + # - Unreference unused slots + # - Reuse memory when possible + # - Account for differently-sized data + =# + + # Get the set of all processors to be scheduled on + all_procs = Processor[] + scope = get_compute_scope() + for w in procs() + append!(all_procs, get_processors(OSProc(w))) + end + filter!(proc->proc_in_scope(proc, scope), all_procs) + if isempty(all_procs) + throw(Sch.SchedulingException("No processors available, try widening scope")) + end + scope = UnionScope(map(ExactScope, all_procs)) + exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), 
all_procs)...)) + if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces) + @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1 + end + + # Round-robin assign tasks to processors + upper_queue = get_options(:task_queue) + + traversal = queue.traversal + if traversal == :inorder + # As-is + task_order = Colon() + elseif traversal == :bfs + # BFS + task_order = Int[1] + to_walk = Int[1] + seen = Set{Int}([1]) + while !isempty(to_walk) + # N.B. next_root has already been seen + next_root = popfirst!(to_walk) + for v in outneighbors(queue.g, next_root) + if !(v in seen) + push!(task_order, v) + push!(seen, v) + push!(to_walk, v) + end + end + end + elseif traversal == :dfs + # DFS (modified with backtracking) + task_order = Int[] + to_walk = Int[1] + seen = Set{Int}() + while length(task_order) < length(queue.seen_tasks) && !isempty(to_walk) + next_root = popfirst!(to_walk) + if !(next_root in seen) + iv = inneighbors(queue.g, next_root) + if all(v->v in seen, iv) + push!(task_order, next_root) + push!(seen, next_root) + ov = outneighbors(queue.g, next_root) + prepend!(to_walk, ov) + else + push!(to_walk, next_root) + end + end + end + else + throw(ArgumentError("Invalid traversal mode: $traversal")) + end + + state = DataDepsState(queue.aliasing) + sstate = DataDepsSchedulerState() + for proc in all_procs + space = only(memory_spaces(proc)) + get!(()->0, sstate.capacities, space) + sstate.capacities[space] += 1 + end + + # Start launching tasks and necessary copies + write_num = 1 + proc_idx = 1 + #pressures = Dict{Processor,Int}() + proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) + for pair in queue.seen_tasks[task_order] + spec = pair.spec + task = pair.task + write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx) + end + + # 
Copy args from remote to local + # N.B. We sort the keys to ensure a deterministic order for uniformity + for arg_w in sort(collect(keys(state.arg_owner)); by=arg_w->arg_w.hash) + arg = arg_w.arg + origin_space = state.arg_origin[arg] + remainder, _ = compute_remainder_for_arg!(state, origin_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_remainder_copy_from!(state, origin_space, arg_w, remainder, origin_scope, write_num) + elseif remainder isa FullCopy + origin_scope = UnionScope(map(ExactScope, collect(processors(origin_space)))...) + enqueue_copy_from!(state, origin_space, arg_w, origin_scope, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "Skipped copy-from (up-to-date): $origin_space" + end + end +end +struct DataDepsTaskDependency + arg_w::ArgumentWrapper + readdep::Bool + writedep::Bool +end +DataDepsTaskDependency(arg, dep) = + DataDepsTaskDependency(ArgumentWrapper(arg, dep[1]), dep[2], dep[3]) +struct DataDepsTaskArgument + arg + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::Vector{DataDepsTaskDependency} +end +struct TypedDataDepsTaskArgument{T,N} + arg::T + pos::ArgPosition + may_alias::Bool + inplace_move::Bool + deps::NTuple{N,DataDepsTaskDependency} +end +map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs)) +map_or_ntuple(f, xs::Tuple) = ntuple(f, length(xs)) +function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed + @specialize spec fargs + + if typed + fargs::Tuple + else + fargs::Vector{Argument} + end + + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + scheduler = queue.scheduler + if scheduler == :naive + raw_args = map(arg->tochunk(value(arg)), 
spec.fargs) + our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + # Calculate costs per processor and select the most optimal + # FIXME: This should consider any already-allocated slots, + # whether they are up-to-date, and if not, the cost of moving + # data to them + procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args) + return first(procs) + end + end + elseif scheduler == :smart + raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg + arg_chunk = tochunk(value(arg)) + # Only the owned slot is valid + # FIXME: Track up-to-date copies and pass all of those + return arg_chunk => data_locality[arg] + end + f_chunk = tochunk(value(spec.fargs[1])) + our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + + @lock sch_state.lock begin + tx_rate = sch_state.transfer_rate[] + + costs = Dict{Processor,Float64}() + for proc in all_procs + # Filter out chunks that are already local + chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality) + + # Estimate network transfer costs based on data size + # N.B. `affinity(x)` really means "data size of `x`" + # N.B. 
We treat same-worker transfers as having zero transfer cost + tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt) + + # Estimate total cost to move data and get task running after currently-scheduled tasks + est_time_util = get(pressures, proc, UInt64(0)) + costs[proc] = est_time_util + (tx_cost/tx_rate) + end + + # Look up estimated task cost + sig = Sch.signature(sch_state, f, map(first, chunks_locality)) + task_pressure = get(sch_state.signature_time_cost, sig, 1000^3) + + # Shuffle procs around, so equally-costly procs are equally considered + P = randperm(length(all_procs)) + procs = getindex.(Ref(all_procs), P) + + # Sort by lowest cost first + sort!(procs, by=p->costs[p]) + + best_proc = first(procs) + return best_proc, task_pressure + end + end + # FIXME: Pressure should be decreased by pressure of syncdeps on same processor + pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure + elseif scheduler == :ultra + args = Base.mapany(spec.fargs) do arg + pos, data = arg + data, _ = unwrap_inout(data) + if data isa DTask + data = fetch(data; move_value=false, unwrap=false) + end + return pos => tochunk(data) + end + f_chunk = tochunk(value(spec.fargs[1])) + task_time = remotecall_fetch(1, f_chunk, args) do f, args + Sch.init_eager() + sch_state = Sch.EAGER_STATE[] + return @lock sch_state.lock begin + sig = Sch.signature(sch_state, f, args) + return get(sch_state.signature_time_cost, sig, 1000^3) + end + end + + # FIXME: Copy deps are computed eagerly + deps = @something(spec.options.syncdeps, Set{Any}()) + + # Find latest time-to-completion of all syncdeps + deps_completed = UInt64(0) + for dep in deps + haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded + deps_completed = max(deps_completed, sstate.task_completions[dep]) + end + + # Find latest time-to-completion of each memory space + # FIXME: Figure out space completions based on optimal packing + spaces_completed = Dict{MemorySpace,UInt64}() + 
for space in exec_spaces + completed = UInt64(0) + for (task, other_space) in sstate.assignments + space == other_space || continue + completed = max(completed, sstate.task_completions[task]) + end + spaces_completed[space] = completed + end + + # Choose the earliest-available memory space and processor + # FIXME: Consider move time + move_time = UInt64(0) + local our_space_completed + while true + our_space_completed, our_space = findmin(spaces_completed) + our_space_procs = filter(proc->proc in all_procs, processors(our_space)) + if isempty(our_space_procs) + delete!(spaces_completed, our_space) + continue + end + our_proc = rand(our_space_procs) + break + end + + sstate.task_to_spec[task] = spec + sstate.assignments[task] = our_space + sstate.task_completions[task] = our_space_completed + move_time + task_time + elseif scheduler == :roundrobin + our_proc = all_procs[proc_idx] + if task_scope == scope + # all_procs is already limited to scope + else + if isa(constrain(task_scope, scope), InvalidScope) + throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)")) + end + while !proc_in_scope(our_proc, task_scope) + proc_idx = mod1(proc_idx + 1, length(all_procs)) + our_proc = all_procs[proc_idx] + end + end + else + error("Invalid scheduler: $sched") + end + @assert our_proc in all_procs + our_space = only(memory_spaces(our_proc)) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = 
constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end + if our_scope isa InvalidScope + throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) + end + + f = spec.fargs[1] + # FIXME: May not be correct to move this under uniformity + #f.value = move(default_processor(), our_proc, value(f)) + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Scheduling: $our_proc ($our_space)" + + # Copy raw task arguments for analysis + # N.B. Used later for checking dependencies + task_args = map_or_ntuple(idx->copy(spec.fargs[idx]), spec.fargs) + + # Populate all task dependencies + task_arg_ws = populate_task_info!(state, task_args, spec, task) + + # Truncate the history for each argument + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + truncate_history!(state, dep.arg_w) + end + return + end + + # Copy args from local to remote + remote_args = map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + pos = raw_position(arg_ws.pos) + + # Is the data written previously or now? + if !arg_ws.may_alias + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (immutable)" + return arg + end + + # Is the data writeable? + if !arg_ws.inplace_move + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)] Skipped copy-to (non-writeable)" + return arg + end + + # Is the source of truth elsewhere? 
+ arg_remote = get_or_generate_slot!(state, our_space, arg) + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + dep_mod = arg_w.dep_mod + remainder, _ = compute_remainder_for_arg!(state, our_space, arg_w, write_num) + if remainder isa MultiRemainderAliasing + enqueue_remainder_copy_to!(state, our_space, arg_w, remainder, value(f), idx, our_scope, task, write_num) + elseif remainder isa FullCopy + enqueue_copy_to!(state, our_space, arg_w, value(f), idx, our_scope, task, write_num) + else + @assert remainder isa NoAliasing "Expected NoAliasing, got $(typeof(remainder))" + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Skipped copy-to (up-to-date): $our_space" + end + end + return arg_remote + end + write_num += 1 + + # Validate that we're not accidentally performing a copy + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = remote_args[idx] + + # Get the dependencies again as (dep_mod, readdep, writedep) + deps = map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + (dep.arg_w.dep_mod, dep.readdep, dep.writedep) + end + + # Check that any mutable and written arguments are already in the correct space + # N.B. 
We only do this check when the argument supports in-place + # moves, because for the moment, we are not guaranteeing updates or + # write-back of results + if is_writedep(arg, deps, task) && arg_ws.may_alias && arg_ws.inplace_move + arg_space = memory_space(arg) + @assert arg_space == our_space "($(repr(value(f))))[$(idx-1)] Tried to pass $(typeof(arg)) from $arg_space to $our_space" + end + end + + # Calculate this task's syncdeps + if spec.options.syncdeps === nothing + spec.options.syncdeps = Set{ThunkSyncdep}() + end + syncdeps = spec.options.syncdeps + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + map_or_ntuple(arg_ws.deps) do dep_idx + dep = arg_ws.deps[dep_idx] + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as writer" + get_write_deps!(state, our_space, ainfo, write_num, syncdeps) + else + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Syncing as reader" + get_read_deps!(state, our_space, ainfo, write_num, syncdeps) + end + end + return + end + @dagdebug nothing :spawn_datadeps "($(repr(value(f)))) Task has $(length(syncdeps)) syncdeps" + + # Launch user's task + new_fargs = map_or_ntuple(task_arg_ws) do idx + if is_typed(spec) + return TypedArgument(task_arg_ws[idx].pos, remote_args[idx]) + else + return Argument(task_arg_ws[idx].pos, remote_args[idx]) + end + end + new_spec = DTaskSpec(new_fargs, spec.options) + new_spec.options.scope = our_scope + new_spec.options.exec_scope = our_scope + new_spec.options.occupancy = Dict(Any=>0) + enqueue!(queue.upper_queue, DTaskPair(new_spec, task)) + + # Update read/write tracking for arguments + map_or_ntuple(task_arg_ws) do idx + arg_ws = task_arg_ws[idx] + arg = arg_ws.arg + arg_ws.may_alias || return + arg_ws.inplace_move || return + for dep in 
arg_ws.deps + arg_w = dep.arg_w + ainfo = aliasing!(state, our_space, arg_w) + dep_mod = arg_w.dep_mod + if dep.writedep + @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$(idx-1)][$dep_mod] Task set as writer" + add_writer!(state, arg_w, our_space, ainfo, task, write_num) + else + add_reader!(state, arg_w, our_space, ainfo, task, write_num) + end + end + return + end + + write_num += 1 + proc_idx = mod1(proc_idx + 1, length(all_procs)) + + return write_num, proc_idx +end diff --git a/src/datadeps/remainders.jl b/src/datadeps/remainders.jl new file mode 100644 index 00000000..b420ca5d --- /dev/null +++ b/src/datadeps/remainders.jl @@ -0,0 +1,443 @@ +# Remainder tracking and computation functions + +""" + RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + +Represents the memory spans that remain after subtracting some regions from a base aliasing object. +This is used to perform partial data copies that only update the "remainder" regions. +""" +struct RemainderAliasing{S<:MemorySpace} <: AbstractAliasing + space::S + spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}} + syncdeps::Set{ThunkSyncdep} +end +RemainderAliasing(space::S, spans::Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}, syncdeps::Set{ThunkSyncdep}) where S = + RemainderAliasing{S}(space, spans, syncdeps) + +memory_spans(ra::RemainderAliasing) = ra.spans + +Base.hash(ra::RemainderAliasing, h::UInt) = hash(ra.spans, hash(RemainderAliasing, h)) +Base.:(==)(ra1::RemainderAliasing, ra2::RemainderAliasing) = ra1.spans == ra2.spans + +# Add will_alias support for RemainderAliasing +function will_alias(x::RemainderAliasing, y::AbstractAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::AbstractAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +function will_alias(x::RemainderAliasing, y::RemainderAliasing) + return will_alias(memory_spans(x), memory_spans(y)) +end + +struct MultiRemainderAliasing <: 
AbstractAliasing + remainders::Vector{<:RemainderAliasing} +end +MultiRemainderAliasing() = MultiRemainderAliasing(RemainderAliasing[]) + +memory_spans(mra::MultiRemainderAliasing) = vcat(memory_spans.(mra.remainders)...) + +Base.hash(mra::MultiRemainderAliasing, h::UInt) = hash(mra.remainders, hash(MultiRemainderAliasing, h)) +Base.:(==)(mra1::MultiRemainderAliasing, mra2::MultiRemainderAliasing) = mra1.remainders == mra2.remainders + +#= FIXME: Integrate with main documentation +Problem statement: + +Remainder copy calculation needs to ensure that, for a given argument and +dependency modifier, and for a given target memory space, any data not yet +updated (whether through this arg or through another that aliases) is added to +the remainder, while any data that has been updated is not in the remainder. +Remainder copies may be multi-part, as data may be spread across multiple other +memory spaces. + +Ainfo is not alone sufficient to identify the combination of argument and +dependency modifier, as ainfo is specific to an allocation in a given memory +space. Thus, this combination needs to be tracked together, and separately from +memory space. However, information may span multiple memory spaces (and thus +multiple ainfos), so we should try to make queries of cross-memory space +information fast, as they will need to be performed for every task, for every +combination. 
+ +Game Plan: + +- Use ArgumentWrapper to track this combination throughout the codebase, ideally generated just once +- Maintain the keying of remote_args only on argument, as the dependency modifier doesn’t affect the argument being passed into the task, so it should not factor into generating and tracking remote argument copies +- Add a structure to track the mapping from ArgumentWrapper to memory space to ainfo, as a quick way to lookup all ainfos needing to be considered +- When considering a remainder copy, only look at a single memory space’s ainfos at a time, as the ainfos should overlap exactly the same way on any memory space, and this allows us to use ainfo_overlaps to track overlaps +- Remainder copies will need to separately consider the source memory space, and the destination memory space when acquiring spans to copy to/from +- Memory spans for ainfos generated from the same ArgumentWrapper should be assumed to be paired in the same order, regardless of memory space, to ensure we can perform the translation from source to destination span address + - Alternatively, we might provide an API to take source and destination ainfos, and desired remainder memory spans, which then performs the copy for us +- When a task or copy writes to arguments, we should record this happening for all overlapping ainfos, in a manner that will be efficient to query from another memory space. We can probably walk backwards and attach this to a structure keyed on ArgumentWrapper, as that will be very efficient for later queries (because the history will now be linearized in one vector). +- Remainder copies will need to know, for all overlapping ainfos of the ArgumentWrapper ainfo at the target memory space, how recently that ainfo was updated relative to other ainfos, and relative to how recently the target ainfo was written. 
+ - The last time the target ainfo was written is the furthest back we need to consider, as the target data must have been fully up-to-date when that write completed.
+ - Consideration of updates should start at most recent first, walking backwards in time, as the most recent updates contain the up-to-date data.
+ - For each span under consideration, we should subtract from it the current remainder set, to ensure we only copy up-to-date data.
+ - We must add that span portion to the remainder set no matter what, but if it was updated on the target memory space, we don’t need to schedule a copy for it, since it’s already where it needs to be.
+ - Even before the last target write is seen, we are allowed to stop searching if we find that our target ainfo is fully covered (because this implies that the target ainfo is fully out-of-date).
+=#
+
+struct FullCopy end
+
+"""
+    compute_remainder_for_arg!(state::DataDepsState,
+                               target_space::MemorySpace,
+                               arg_w::ArgumentWrapper)
+
+Computes what remainder regions need to be copied to `target_space` before a task can access `arg_w`.
+Returns a `MultiRemainderAliasing` object representing the remainder, or `NoAliasing()` if no remainder needed.
+
+The algorithm starts by collecting the memory spans of `arg_w` in `target_space` - this is the "remainder".
+When this remainder is empty, the algorithm will be finished.
+Additionally, a dictionary is created to store the source and destination
+memory spans (for each source memory space) that will be used to create the
+`MultiRemainderAliasing` object - this is the "tracker".
+
+The algorithm walks backwards through the `arg_history` vector for `arg_w`
+(which is an ordered list of all overlapping ainfos that were directly written to (potentially in a different memory space than `target_space`)
+since the last time this `arg_w` was written to). 
If this ainfo is in `target_space`, +then it is not under consideration; it is simply subtraced from the remainder with `subtract_remainder!`, +and the algorithm goes to the next ainfo. Otherwise, the algorithm will consider this ainfo for tracking. + +For each overlapping ainfo (which lives in a different memory space than `target_space`) to be tracked, there exists a corresponding "mirror" ainfo in +`target_space`, which is the equivalent of the overlapping ainfo, but in +`target_space`. This mirror ainfo is assumed to have an identical number of memory spans as the overlapping ainfo, +and each memory span is assumed to be identical in size, but not necessarily identical in address. + +These three sets of memory spans (from the remainder, the overlapping ainfo, and the mirror ainfo) are then passed to `schedule_aliasing!`. +This call will subtract the spans of the mirror ainfo from the remainder (as the two live in the same memory space and thus can be directly compared), +and will update the remainder accordingly. +Additionaly, it will also use this subtraction to update the tracker, by adding the equivalent spans (mapped from mirror ainfo to overlapping ainfo) to the tracker as the source, +and the spans of the remainder as the destination. + +If the history is exhausted without the remainder becoming empty, then the +remaining data in `target_space` is assumed to be up-to-date (as the latest write +to `arg_w` is the furthest back we need to consider). + +Finally, the tracker is converted into a `MultiRemainderAliasing` object, +and returned. 
+""" +function compute_remainder_for_arg!(state::DataDepsState, + target_space::MemorySpace, + arg_w::ArgumentWrapper, + write_num::Int; compute_syncdeps::Bool=true) + @label restart + + # Determine all memory spaces of the history + spaces_set = Set{MemorySpace}() + push!(spaces_set, target_space) + owner_space = state.arg_owner[arg_w] + push!(spaces_set, owner_space) + for entry in state.arg_history[arg_w] + push!(spaces_set, entry.space) + end + spaces = collect(spaces_set) + N = length(spaces) + + # Lookup all memory spans for arg_w in these spaces + target_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + target_space_ainfo = aliasing!(state, space, arg_w) + spans = memory_spans(target_space_ainfo) + push!(target_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(target_ainfos)) + + # FIXME: This is a hack to ensure that we don't miss any history generated by aliasing(...) + for entry in state.arg_history[arg_w] + if !in(entry.space, spaces) + @opcounter :compute_remainder_for_arg_restart + @goto restart + end + end + + # We may only need to schedule a full copy from the origin space to the + # target space if this is the first time we've written to `arg_w` + if isempty(state.arg_history[arg_w]) + if owner_space != target_space + return FullCopy(), 0 + else + return NoAliasing(), 0 + end + end + + # Create our remainder as an interval tree over all target ainfos + remainder = IntervalTree{ManyMemorySpan{N}}(ManyMemorySpan{N}(ntuple(i -> target_ainfos[i][j], N)) for j in 1:nspans) + + # Create our tracker + tracker = Dict{MemorySpace,Tuple{Vector{Tuple{LocalMemorySpan,LocalMemorySpan}},Set{ThunkSyncdep}}}() + + # Walk backwards through the history of writes to this target + # other_ainfo is the overlapping ainfo that was written to + # other_space is the memory space of the overlapping ainfo + last_idx = length(state.arg_history[arg_w]) + for idx in length(state.arg_history[arg_w]):-1:0 + if isempty(remainder) + # All done! 
+ last_idx = idx + break + end + + if idx > 0 + other_entry = state.arg_history[arg_w][idx] + other_ainfo = other_entry.ainfo + other_space = other_entry.space + else + # If we've reached the end of the history, evaluate ourselves + other_ainfo = aliasing!(state, owner_space, arg_w) + other_space = owner_space + end + + # Lookup all memory spans for arg_w in these spaces + other_remote_arg_w = state.ainfo_arg[other_ainfo] + other_arg_w = ArgumentWrapper(state.remote_arg_to_original[other_remote_arg_w.arg], other_remote_arg_w.dep_mod) + other_ainfos = Vector{Vector{LocalMemorySpan}}() + for space in spaces + other_space_ainfo = aliasing!(state, space, other_arg_w) + spans = memory_spans(other_space_ainfo) + push!(other_ainfos, LocalMemorySpan.(spans)) + end + nspans = length(first(other_ainfos)) + other_many_spans = [ManyMemorySpan{N}(ntuple(i -> other_ainfos[i][j], N)) for j in 1:nspans] + + if other_space == target_space + # Only subtract, this data is already up-to-date in target_space + # N.B. We don't add to syncdeps here, because we'll see this ainfo + # in get_write_deps! 
+ @opcounter :compute_remainder_for_arg_subtract + subtract_spans!(remainder, other_many_spans) + continue + end + + # Subtract from remainder and schedule copy in tracker + other_space_idx = something(findfirst(==(other_space), spaces)) + target_space_idx = something(findfirst(==(target_space), spaces)) + tracker_other_space = get!(tracker, other_space) do + (Vector{Tuple{LocalMemorySpan,LocalMemorySpan}}(), Set{ThunkSyncdep}()) + end + @opcounter :compute_remainder_for_arg_schedule + schedule_remainder!(tracker_other_space[1], other_space_idx, target_space_idx, remainder, other_many_spans) + if compute_syncdeps + @assert haskey(state.ainfos_owner, other_ainfo) "[idx $idx] ainfo $(typeof(other_ainfo)) has no owner" + get_read_deps!(state, other_space, other_ainfo, write_num, tracker_other_space[2]) + end + end + + if isempty(tracker) + return NoAliasing(), 0 + end + + # Return scheduled copies and the index of the last ainfo we considered + mra = MultiRemainderAliasing() + for space in spaces + if haskey(tracker, space) + spans, syncdeps = tracker[space] + if !isempty(spans) + push!(mra.remainders, RemainderAliasing(space, spans, syncdeps)) + end + end + end + return mra, last_idx +end + +### Memory Span Set Operations for Remainder Computation + +""" + schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) + +Calculates the difference between `remainder` and `other_many_spans`, subtracts +it from `remainder`, and then adds that difference to `tracker` as a scheduled +copy from `other_many_spans` to the subtraced portion of `remainder`. 
+""" +function schedule_remainder!(tracker::Vector, source_space_idx::Int, dest_space_idx::Int, remainder::IntervalTree, other_many_spans::Vector{ManyMemorySpan{N}}) where N + diff = Vector{ManyMemorySpan{N}}() + subtract_spans!(remainder, other_many_spans, diff) + + for span in diff + source_span = span.spans[source_space_idx] + dest_span = span.spans[dest_space_idx] + push!(tracker, (source_span, dest_span)) + end +end + +### Remainder copy functions + +""" + enqueue_remainder_copy_to!(state::DataDepsState, f, target_ainfo::AliasingWrapper, remainder_aliasing, dep_mod, arg, idx, + our_space::MemorySpace, our_scope, task::DTask, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object before a task runs. +""" +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_to!(state, dest_space, arg_w, remainder, f, idx, dest_scope, task, write_num) + end +end +function enqueue_remainder_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing remainder copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps 
= Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! + get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Remainder copy-to has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end +""" + enqueue_remainder_copy_from!(state::DataDepsState, target_ainfo::AliasingWrapper, arg, remainder_aliasing, + origin_space::MemorySpace, origin_scope, write_num::Int) + +Enqueues a copy operation to update the remainder regions of an object back to the original space. 
+""" +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::MultiRemainderAliasing, + dest_scope, write_num::Int) + for remainder in remainder_aliasing.remainders + @assert !isempty(remainder.spans) + enqueue_remainder_copy_from!(state, dest_space, arg_w, remainder, dest_scope, write_num) + end +end +function enqueue_remainder_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, remainder_aliasing::RemainderAliasing, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + + # Find the source space for the remainder data + # We need to find where the best version of the target data lives that hasn't been + # overwritten by more recent partial updates + source_space = remainder_aliasing.space + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing remainder copy-from for: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + remainder_syncdeps = Set{Any}() + target_ainfo = aliasing!(state, dest_space, arg_w) + for syncdep in remainder_aliasing.syncdeps + push!(remainder_syncdeps, syncdep) + end + empty!(remainder_aliasing.syncdeps) # We can't bring these to move! 
+ get_write_deps!(state, dest_space, target_ainfo, write_num, remainder_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Remainder copy-from has $(length(remainder_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=remainder_syncdeps meta=true Dagger.move!(remainder_aliasing, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# FIXME: Document me +function enqueue_copy_to!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + f, idx, dest_scope, task, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Enqueueing full copy-to for $(typeof(arg_w.arg))[$(arg_w.dep_mod)]: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(repr(f)))[$(idx-1)][$dep_mod] Full copy-to has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, 
copy_task, write_num) +end +function enqueue_copy_from!(state::DataDepsState, dest_space::MemorySpace, arg_w::ArgumentWrapper, + dest_scope, write_num::Int) + dep_mod = arg_w.dep_mod + source_space = state.arg_owner[arg_w] + target_ainfo = aliasing!(state, dest_space, arg_w) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Enqueueing full copy-from: $source_space => $dest_space" + + # Get the source and destination arguments + arg_dest = state.remote_args[dest_space][arg_w.arg] + arg_source = get_or_generate_slot!(state, source_space, arg_w.arg) + + # Create a copy task for the remainder + copy_syncdeps = Set{Any}() + source_ainfo = aliasing!(state, source_space, arg_w) + target_ainfo = aliasing!(state, dest_space, arg_w) + get_read_deps!(state, source_space, source_ainfo, write_num, copy_syncdeps) + get_write_deps!(state, dest_space, target_ainfo, write_num, copy_syncdeps) + + @dagdebug nothing :spawn_datadeps "($(typeof(arg_w.arg)))[$dep_mod] Full copy-from has $(length(copy_syncdeps)) syncdeps" + + # Launch the remainder copy task + copy_task = Dagger.@spawn scope=dest_scope exec_scope=dest_scope syncdeps=copy_syncdeps meta=true Dagger.move!(dep_mod, dest_space, source_space, arg_dest, arg_source) + + # This copy task becomes a new writer for the target region + add_writer!(state, arg_w, dest_space, target_ainfo, copy_task, write_num) +end + +# Main copy function for RemainderAliasing +function move!(dep_mod::RemainderAliasing{S}, to_space::MemorySpace, from_space::MemorySpace, to::Chunk, from::Chunk) where S + # Get the source data for each span + copies = remotecall_fetch(root_worker_id(from_space), dep_mod) do dep_mod + copies = Vector{UInt8}[] + for (from_span, _) in dep_mod.spans + copy = Vector{UInt8}(undef, from_span.len) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(from_span.ptr) + to_ptr = Ptr{UInt8}(pointer(copy)) + unsafe_copyto!(to_ptr, from_ptr, from_span.len) + end + push!(copies, copy) + end + return copies + end + + # 
Copy the data into the destination object + for (copy, (_, to_span)) in zip(copies, dep_mod.spans) + GC.@preserve copy begin + from_ptr = Ptr{UInt8}(pointer(copy)) + to_ptr = Ptr{UInt8}(to_span.ptr) + unsafe_copyto!(to_ptr, from_ptr, to_span.len) + end + end + + # Ensure that the data is visible + Core.Intrinsics.atomic_fence(:release) + + return +end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 9f65a1a2..312351a3 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -1,5 +1,3 @@ -abstract type MemorySpace end - struct CPURAMMemorySpace <: MemorySpace owner::Int end @@ -30,7 +28,7 @@ processors(space::CPURAMMemorySpace) = ### In-place Data Movement function unwrap(x::Chunk) - @assert root_worker_id(x.processor) == myid() + @assert x.handle.owner == myid() MemPool.poolget(x.handle) end move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::T, from::F) where {T,F} = @@ -92,41 +90,16 @@ end may_alias(::MemorySpace, ::MemorySpace) = true may_alias(space1::CPURAMMemorySpace, space2::CPURAMMemorySpace) = space1.owner == space2.owner -struct RemotePtr{T,S<:MemorySpace} <: Ref{T} - addr::UInt - space::S -end -RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) -RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) -RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) -Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = - RemotePtr(UInt(x), CPURAMMemorySpace(myid())) -Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = - RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) -Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) -Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) -function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) - @assert ptr1.space == ptr2.space - return ptr1.addr < ptr2.addr -end - -struct MemorySpan{S} - 
ptr::RemotePtr{Cvoid,S} - len::UInt -end -MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = - MemorySpan{S}(ptr, UInt(len)) - abstract type AbstractAliasing end memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `memory_spans` for `$T`")) memory_spans(x) = memory_spans(aliasing(x)) memory_spans(x, T) = memory_spans(aliasing(x, T)) +### Type-generic aliasing info wrapper -struct AliasingWrapper <: AbstractAliasing +mutable struct AliasingWrapper <: AbstractAliasing inner::AbstractAliasing hash::UInt64 - AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner)) end memory_spans(x::AliasingWrapper) = memory_spans(x.inner) @@ -135,8 +108,202 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash -will_alias(x::AliasingWrapper, y::AliasingWrapper) = - will_alias(x.inner, y.inner) +will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) + +### Small dictionary type + +struct SmallDict{K,V} <: AbstractDict{K,V} + keys::Vector{K} + vals::Vector{V} +end +SmallDict{K,V}() where {K,V} = SmallDict{K,V}(Vector{K}(), Vector{V}()) +function Base.getindex(d::SmallDict{K,V}, key) where {K,V} + key_idx = findfirst(==(convert(K, key)), d.keys) + if key_idx === nothing + throw(KeyError(key)) + end + return @inbounds d.vals[key_idx] +end +function Base.setindex!(d::SmallDict{K,V}, val, key) where {K,V} + key_conv = convert(K, key) + key_idx = findfirst(==(key_conv), d.keys) + if key_idx === nothing + push!(d.keys, key_conv) + push!(d.vals, convert(V, val)) + else + d.vals[key_idx] = convert(V, val) + end + return val +end +Base.haskey(d::SmallDict{K,V}, key) where {K,V} = in(convert(K, key), d.keys) +Base.keys(d::SmallDict) = d.keys +Base.length(d::SmallDict) = length(d.keys) +Base.iterate(d::SmallDict) = 
iterate(d, 1) +Base.iterate(d::SmallDict, state) = state > length(d.keys) ? nothing : (d.keys[state] => d.vals[state], state+1) + +### Type-stable lookup structure for AliasingWrappers + +struct AliasingLookup + # The set of memory spaces that are being tracked + spaces::Vector{MemorySpace} + # The set of AliasingWrappers that are being tracked + # One entry for each AliasingWrapper + ainfos::Vector{AliasingWrapper} + # The memory spaces for each AliasingWrapper + # One entry for each AliasingWrapper + ainfos_spaces::Vector{Vector{Int}} + # The spans for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + spans::Vector{SmallDict{Int,Vector{LocalMemorySpan}}} + # The set of AliasingWrappers that only exist in a single memory space + # One entry for each AliasingWrapper + ainfos_only_space::Vector{Int} + # The bounding span for each AliasingWrapper in each memory space + # One entry for each AliasingWrapper + bounding_spans::Vector{SmallDict{Int,LocalMemorySpan}} + # The interval tree of the bounding spans for each AliasingWrapper + # One entry for each MemorySpace + bounding_spans_tree::Vector{IntervalTree{LocatorMemorySpan{Int},UInt64}} + + AliasingLookup() = new(MemorySpace[], + AliasingWrapper[], + Vector{Int}[], + SmallDict{Int,Vector{LocalMemorySpan}}[], + Int[], + SmallDict{Int,LocalMemorySpan}[], + IntervalTree{LocatorMemorySpan{Int},UInt64}[]) +end +function Base.push!(lookup::AliasingLookup, ainfo::AliasingWrapper) + # Update the set of memory spaces and spans, + # and find the bounding spans for this AliasingWrapper + spaces_set = Set{MemorySpace}(lookup.spaces) + self_spaces_set = Set{Int}() + spans = SmallDict{Int,Vector{LocalMemorySpan}}() + for span in memory_spans(ainfo) + space = span.ptr.space + if !in(space, spaces_set) + push!(spaces_set, space) + push!(lookup.spaces, space) + push!(lookup.bounding_spans_tree, IntervalTree{LocatorMemorySpan{Int}}()) + end + space_idx = findfirst(==(space), lookup.spaces) + 
push!(self_spaces_set, space_idx) + spans_in_space = get!(Vector{LocalMemorySpan}, spans, space_idx) + push!(spans_in_space, LocalMemorySpan(span)) + end + push!(lookup.ainfos_spaces, collect(self_spaces_set)) + push!(lookup.spans, spans) + + # Update the set of AliasingWrappers + push!(lookup.ainfos, ainfo) + ainfo_idx = length(lookup.ainfos) + + # Check if the AliasingWrapper only exists in a single memory space + if length(self_spaces_set) == 1 + space_idx = only(self_spaces_set) + push!(lookup.ainfos_only_space, space_idx) + else + push!(lookup.ainfos_only_space, 0) + end + + # Add the bounding spans for this AliasingWrapper + bounding_spans = SmallDict{Int,LocalMemorySpan}() + for space_idx in keys(spans) + space_spans = spans[space_idx] + bound_start = minimum(span_start, space_spans) + bound_end = maximum(span_end, space_spans) + bounding_span = LocalMemorySpan(bound_start, bound_end - bound_start) + bounding_spans[space_idx] = bounding_span + insert!(lookup.bounding_spans_tree[space_idx], LocatorMemorySpan(bounding_span, ainfo_idx)) + end + push!(lookup.bounding_spans, bounding_spans) + + return ainfo_idx +end +struct AliasingLookupFinder + lookup::AliasingLookup + ainfo::AliasingWrapper + ainfo_idx::Int + spaces_idx::Vector{Int} + to_consider::Vector{Int} +end +Base.eltype(::AliasingLookupFinder) = AliasingWrapper +Base.IteratorSize(::AliasingLookupFinder) = Base.SizeUnknown() +# FIXME: We should use a Dict{UInt,Int} to find the ainfo_idx instead of linear search +function Base.intersect(lookup::AliasingLookup, ainfo::AliasingWrapper; ainfo_idx=nothing) + if ainfo_idx === nothing + ainfo_idx = something(findfirst(==(ainfo), lookup.ainfos)) + end + spaces_idx = lookup.ainfos_spaces[ainfo_idx] + to_consider_spans = LocatorMemorySpan{Int}[] + for space_idx in spaces_idx + bounding_spans_tree = lookup.bounding_spans_tree[space_idx] + self_bounding_span = LocatorMemorySpan(lookup.bounding_spans[ainfo_idx][space_idx], 0) + find_overlapping!(bounding_spans_tree, 
self_bounding_span, to_consider_spans; exact=false) + end + to_consider = Int[locator.owner for locator in to_consider_spans] + @assert all(to_consider .> 0) + return AliasingLookupFinder(lookup, ainfo, ainfo_idx, spaces_idx, to_consider) +end +Base.iterate(finder::AliasingLookupFinder) = iterate(finder, 1) +function Base.iterate(finder::AliasingLookupFinder, cursor_ainfo_idx) + ainfo_spaces = nothing + cursor_space_idx = 1 + + # New ainfos enter here + @label ainfo_restart + + # Check if we've exhausted all ainfos + if cursor_ainfo_idx > length(finder.to_consider) + return nothing + end + ainfo_idx = finder.to_consider[cursor_ainfo_idx] + + # Find the appropriate memory spaces for this ainfo + if ainfo_spaces === nothing + ainfo_spaces = finder.lookup.ainfos_spaces[ainfo_idx] + end + + # New memory spaces (for the same ainfo) enter here + @label space_restart + + # Check if we've exhausted all memory spaces for this ainfo, and need to move to the next ainfo + if cursor_space_idx > length(ainfo_spaces) + ainfo_idx += 1 + ainfo_spaces = nothing + cursor_space_idx = 1 + @goto ainfo_restart + end + + # Find the currently considered memory space for this ainfo + space_idx = ainfo_spaces[cursor_space_idx] + + # Check if this memory space is part of our target ainfo's spaces + if !(space_idx in finder.spaces_idx) + cursor_space_idx += 1 + @goto space_restart + end + + # Check if this ainfo's bounding span is part of our target ainfo's bounding span in this memory space + other_ainfo_bounding_span = finder.lookup.bounding_spans[ainfo_idx][space_idx] + self_bounding_span = finder.lookup.bounding_spans[finder.ainfo_idx][space_idx] + if !spans_overlap(other_ainfo_bounding_span, self_bounding_span) + cursor_space_idx += 1 + @goto space_restart + end + + # We have a overlapping bounds in the same memory space, so check if the ainfos are aliasing + # This is the slow path! 
+ other_ainfo = finder.lookup.ainfos[ainfo_idx] + aliasing = will_alias(finder.ainfo, other_ainfo) + if !aliasing + ainfo_idx += 1 + @goto ainfo_restart + end + + # We overlap, so return the ainfo and the next ainfo index + return other_ainfo, ainfo_idx+1 +end struct NoAliasing <: AbstractAliasing end memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[] @@ -213,8 +380,14 @@ end aliasing(::String) = NoAliasing() # FIXME: Not necessarily true aliasing(::Symbol) = NoAliasing() aliasing(::Type) = NoAliasing() -aliasing(x::Chunk, T) = remotecall_fetch(root_worker_id(x.processor), x, T) do x, T - aliasing(unwrap(x), T) +function aliasing(x::Chunk, T) + @assert x.handle isa DRef + if root_worker_id(x.processor) == myid() + return aliasing(unwrap(x), T) + end + return remotecall_fetch(root_worker_id(x.processor), x, T) do x, T + aliasing(unwrap(x), T) + end end aliasing(x::Chunk) = remotecall_fetch(root_worker_id(x.processor), x) do x aliasing(unwrap(x)) @@ -279,7 +452,7 @@ function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} return StridedAliasing{T,ndims(x),S}(RemotePtr{Cvoid}(pointer(parent(x))), RemotePtr{Cvoid}(pointer(x)), parentindices(x), - size(x), strides(parent(x))) + size(x), strides(x)) else # FIXME: Also ContiguousAliasing of container #return IteratedAliasing(x) @@ -401,71 +574,3 @@ function will_alias(x_span::MemorySpan, y_span::MemorySpan) y_end = y_span.ptr + y_span.len - 1 return x_span.ptr <= y_end && y_span.ptr <= x_end end - -struct ChunkView{N} - chunk::Chunk - slices::NTuple{N, Union{Int, AbstractRange{Int}, Colon}} -end - -function Base.view(c::Chunk, slices...) 
- if c.domain isa ArrayDomain - nd, sz = ndims(c.domain), size(c.domain) - nd == length(slices) || throw(DimensionMismatch("Expected $nd slices, got $(length(slices))")) - - for (i, s) in enumerate(slices) - if s isa Int - 1 ≤ s ≤ sz[i] || throw(ArgumentError("Index $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s isa AbstractRange - isempty(s) && continue - 1 ≤ first(s) ≤ last(s) ≤ sz[i] || throw(ArgumentError("Range $s out of bounds for dimension $i (size $(sz[i]))")) - elseif s === Colon() - continue - else - throw(ArgumentError("Invalid slice type $(typeof(s)) at dimension $i, Expected Type of Int, AbstractRange, or Colon")) - end - end - end - - return ChunkView(c, slices) -end - -Base.view(c::DTask, slices...) = view(fetch(c; raw=true), slices...) - -function aliasing(x::ChunkView{N}) where N - remotecall_fetch(root_worker_id(x.chunk.processor), x.chunk, x.slices) do x, slices - x = unwrap(x) - v = view(x, slices...) - return aliasing(v) - end -end -memory_space(x::ChunkView) = memory_space(x.chunk) -isremotehandle(x::ChunkView) = true - -#= -function move!(dep_mod, to_space::MemorySpace, from_space::MemorySpace, to::ChunkView, from::ChunkView) - to_w = root_worker_id(to_space) - @assert to_w == myid() - to_raw = unwrap(to.chunk) - from_w = root_worker_id(from_space) - from_raw = to_w == from_w ? unwrap(from.chunk) : remotecall_fetch(f->copy(unwrap(f)), from_w, from.chunk) - from_view = view(from_raw, from.slices...) - to_view = view(to_raw, to.slices...) - move!(dep_mod, to_space, from_space, to_view, from_view) - return -end -=# - -function move(from_proc::Processor, to_proc::Processor, slice::ChunkView) - if from_proc == to_proc - return view(unwrap(slice.chunk), slice.slices...) 
- else - # Need to copy the underlying data, so collapse the view - from_w = root_worker_id(from_proc) - data = remotecall_fetch(from_w, slice.chunk, slice.slices) do chunk, slices - copy(view(unwrap(chunk), slices...)) - end - return move(from_proc, to_proc, data) - end -end - -Base.fetch(slice::ChunkView) = view(fetch(slice.chunk), slice.slices...) \ No newline at end of file diff --git a/src/queue.jl b/src/queue.jl index c8c6007e..37947a0a 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -1,32 +1,63 @@ -mutable struct DTaskSpec - fargs::Vector{Argument} +mutable struct DTaskSpec{typed,FA<:Tuple} + _fargs::Vector{Argument} + _typed_fargs::FA options::Options end +DTaskSpec(fargs::Vector{Argument}, options::Options) = + DTaskSpec{false, Tuple{}}(fargs, (), options) +DTaskSpec(fargs::FA, options::Options) where FA = + DTaskSpec{true, FA}(Argument[], fargs, options) +is_typed(spec::DTaskSpec{typed}) where typed = typed +function Base.getproperty(spec::DTaskSpec{typed}, field::Symbol) where typed + if field === :fargs + if typed + return getfield(spec, :_typed_fargs) + else + return getfield(spec, :_fargs) + end + else + return getfield(spec, field) + end +end + +struct DTaskPair + spec::DTaskSpec + task::DTask +end +is_typed(pair::DTaskPair) = is_typed(pair.spec) +Base.iterate(pair::DTaskPair) = (pair.spec, true) +function Base.iterate(pair::DTaskPair, state::Bool) + if state + return (pair.task, false) + else + return nothing + end +end abstract type AbstractTaskQueue end function enqueue! 
end struct DefaultTaskQueue <: AbstractTaskQueue end -enqueue!(::DefaultTaskQueue, spec::Pair{DTaskSpec,DTask}) = - eager_launch!(spec) -enqueue!(::DefaultTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) = - eager_launch!(specs) +enqueue!(::DefaultTaskQueue, pair::DTaskPair) = + eager_launch!(pair) +enqueue!(::DefaultTaskQueue, pairs::Vector{DTaskPair}) = + eager_launch!(pairs) -enqueue!(spec::Pair{DTaskSpec,DTask}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), spec) -enqueue!(specs::Vector{Pair{DTaskSpec,DTask}}) = - enqueue!(get_options(:task_queue, DefaultTaskQueue()), specs) +enqueue!(pair::DTaskPair) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pair) +enqueue!(pairs::Vector{DTaskPair}) = + enqueue!(get_options(:task_queue, DefaultTaskQueue()), pairs) struct LazyTaskQueue <: AbstractTaskQueue - tasks::Vector{Pair{DTaskSpec,DTask}} - LazyTaskQueue() = new(Pair{DTaskSpec,DTask}[]) + tasks::Vector{DTaskPair} + LazyTaskQueue() = new(DTaskPair[]) end -function enqueue!(queue::LazyTaskQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec) +function enqueue!(queue::LazyTaskQueue, pair::DTaskPair) + push!(queue.tasks, pair) end -function enqueue!(queue::LazyTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - append!(queue.tasks, specs) +function enqueue!(queue::LazyTaskQueue, pairs::Vector{DTaskPair}) + append!(queue.tasks, pairs) end function spawn_bulk(f::Base.Callable) queue = LazyTaskQueue() @@ -50,25 +81,25 @@ function _add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) push!(syncdeps, ThunkSyncdep(task)) end end -function enqueue!(queue::InOrderTaskQueue, spec::Pair{DTaskSpec,DTask}) +function enqueue!(queue::InOrderTaskQueue, pair::DTaskPair) if length(queue.prev_tasks) > 0 - _add_prev_deps!(queue, first(spec)) + _add_prev_deps!(queue, pair.spec) empty!(queue.prev_tasks) end - push!(queue.prev_tasks, last(spec)) - enqueue!(queue.upper_queue, spec) + push!(queue.prev_tasks, pair.task) + enqueue!(queue.upper_queue, pair) end 
-function enqueue!(queue::InOrderTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) +function enqueue!(queue::InOrderTaskQueue, pairs::Vector{DTaskPair}) if length(queue.prev_tasks) > 0 - for (spec, task) in specs - _add_prev_deps!(queue, spec) + for pair in pairs + _add_prev_deps!(queue, pair.spec) end empty!(queue.prev_tasks) end - for (spec, task) in specs - push!(queue.prev_tasks, task) + for pair in pairs + push!(queue.prev_tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function spawn_sequential(f::Base.Callable) queue = InOrderTaskQueue(get_options(:task_queue, DefaultTaskQueue())) @@ -79,15 +110,15 @@ struct WaitAllQueue <: AbstractTaskQueue upper_queue::AbstractTaskQueue tasks::Vector{DTask} end -function enqueue!(queue::WaitAllQueue, spec::Pair{DTaskSpec,DTask}) - push!(queue.tasks, spec[2]) - enqueue!(queue.upper_queue, spec) +function enqueue!(queue::WaitAllQueue, pair::DTaskPair) + push!(queue.tasks, pair.task) + enqueue!(queue.upper_queue, pair) end -function enqueue!(queue::WaitAllQueue, specs::Vector{Pair{DTaskSpec,DTask}}) - for (_, task) in specs - push!(queue.tasks, task) +function enqueue!(queue::WaitAllQueue, pairs::Vector{DTaskPair}) + for pair in pairs + push!(queue.tasks, pair.task) end - enqueue!(queue.upper_queue, specs) + enqueue!(queue.upper_queue, pairs) end function wait_all(f; check_errors::Bool=false) queue = WaitAllQueue(get_options(:task_queue, DefaultTaskQueue()), DTask[]) diff --git a/src/submission.jl b/src/submission.jl index 2e7b1c83..4ff4f229 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -268,24 +268,29 @@ function eager_process_elem_submission_to_local!(id_map, arg::Argument) arg.value = Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) end end -function eager_process_args_submission_to_local!(id_map, spec_pair::Pair{DTaskSpec,DTask}) - spec, task = spec_pair +function eager_process_elem_submission_to_local(id_map, arg::TypedArgument{T}) where T + @assert 
!(T <: Thunk) "Cannot use `Thunk`s in `@spawn`/`spawn`" + if T <: DTask && haskey(id_map, (value(arg)::DTask).uid) + #=FIXME:UNIQUE=# + return Sch.ThunkID(id_map[value(arg).uid], value(arg).thunk_ref) + end + return arg +end +function eager_process_args_submission_to_local!(id_map, spec::DTaskSpec{false}) for arg in spec.fargs eager_process_elem_submission_to_local!(id_map, arg) end end -function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair{DTaskSpec,DTask}}) - for spec_pair in spec_pairs - eager_process_args_submission_to_local!(id_map, spec_pair) - end +function eager_process_args_submission_to_local(id_map, spec::DTaskSpec{true}) + return ntuple(i->eager_process_elem_submission_to_local(id_map, spec.fargs[i]), length(spec.fargs)) end -function DTaskMetadata(spec::DTaskSpec) - f = value(spec.fargs[1]) +DTaskMetadata(spec::DTaskSpec) = DTaskMetadata(eager_metadata(spec.fargs)) +function eager_metadata(fargs) + f = value(fargs[1]) f = f isa StreamingFunction ? f.f : f - arg_types = ntuple(i->chunktype(value(spec.fargs[i+1])), length(spec.fargs)-1) - return_type = Base.promote_op(f, arg_types...) - return DTaskMetadata(return_type) + arg_types = ntuple(i->chunktype(value(fargs[i+1])), length(fargs)-1) + return Base.promote_op(f, arg_types...) end function eager_spawn(spec::DTaskSpec) @@ -298,48 +303,64 @@ end chunktype(t::DTask) = t.metadata.return_type -function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) +function eager_launch!(pair::DTaskPair) + spec = pair.spec + task = pair.task + # Assign a name, if specified eager_assign_name!(spec, task) # Lookup DTask -> ThunkID - lock(Sch.EAGER_ID_MAP) do id_map - eager_process_args_submission_to_local!(id_map, spec=>task) + fargs = lock(Sch.EAGER_ID_MAP) do id_map + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end end # Submit the task #=FIXME:REALLOC=# thunk_id = eager_submit!(PayloadOne(task.uid, task.future, - spec.fargs, spec.options, true)) + fargs, spec.options, true)) task.thunk_ref = thunk_id.ref end -function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}}) - ntasks = length(specs) +# FIXME: Don't convert Tuple to Vector{Argument} +function eager_launch!(pairs::Vector{DTaskPair}) + ntasks = length(pairs) # Assign a name, if specified - for (spec, task) in specs - eager_assign_name!(spec, task) + for pair in pairs + eager_assign_name!(pair.spec, pair.task) end #=FIXME:REALLOC_N=# - uids = [task.uid for (_, task) in specs] - futures = [task.future for (_, task) in specs] + uids = [pair.task.uid for pair in pairs] + futures = [pair.task.future for pair in pairs] # Get all functions, args/kwargs, and options #=FIXME:REALLOC_N=# all_fargs = lock(Sch.EAGER_ID_MAP) do id_map # Lookup DTask -> ThunkID - eager_process_args_submission_to_local!(id_map, specs) - [spec.fargs for (spec, _) in specs] + return map(pairs) do pair + spec = pair.spec + if is_typed(spec) + return Argument[map(Argument, eager_process_args_submission_to_local(id_map, spec))...] 
+ else + eager_process_args_submission_to_local!(id_map, spec) + return spec.fargs + end + end end - all_options = Options[spec.options for (spec, _) in specs] + all_options = Options[pair.spec.options for pair in pairs] # Submit the tasks #=FIXME:REALLOC=# thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures, all_fargs, all_options, true)) for i in 1:ntasks - task = specs[i][2] + task = pairs[i].task task.thunk_ref = thunk_ids[i].ref end end diff --git a/src/thunk.jl b/src/thunk.jl index 482d6620..e13e299f 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -17,8 +17,6 @@ function unset!(spec::ThunkSpec, _) spec.id = 0 spec.cache_ref = nothing spec.affinity = nothing - compute_scope = DefaultScope() - result_scope = AnyScope() spec.options = nothing end @@ -186,21 +184,19 @@ function args_kwargs_to_arguments(f, args, kwargs) end return args_kwargs end -function args_kwargs_to_arguments(f, args) - @nospecialize f args - args_kwargs = Argument[] - push!(args_kwargs, Argument(ArgPosition(true, 0, :NULL), f)) - pos_ctr = 1 - for idx in 1:length(args) - pos, arg = args[idx]::Pair - if pos === nothing - push!(args_kwargs, Argument(pos_ctr, arg)) - pos_ctr += 1 +function args_kwargs_to_typedarguments(f, args, kwargs) + nargs = 1 + length(args) + length(kwargs) + return ntuple(nargs) do idx + if idx == 1 + return TypedArgument(ArgPosition(true, 0, :NULL), f) + elseif idx in 2:(1+length(args)) + arg = args[idx-1] + return TypedArgument(idx, arg) else - push!(args_kwargs, Argument(pos, arg)) + kw, value = kwargs[idx-length(args)-1] + return TypedArgument(kw, value) end end - return args_kwargs end """ @@ -491,7 +487,11 @@ function _par(mod, ex::Expr; lazy=true, recur=true, opts=()) @gensym result return quote let - $result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + $result = if $get_task_typed() + $typed_spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + else + $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...)) + end if 
$(Expr(:islocal, sync_var)) put!($sync_var, schedule(Task(()->fetch($result; raw=true)))) end @@ -516,6 +516,9 @@ function _setindex!_return_value(A, value, idxs...) return value end +const TASK_TYPED = ScopedValue{Bool}(false) +get_task_typed() = TASK_TYPED[] + """ Dagger.spawn(f, args...; kwargs...) -> DTask @@ -526,6 +529,36 @@ Spawns a `DTask` that will call `f(args...; kwargs...)`. Also supports passing a function spawn(f, args...; kwargs...) @nospecialize f args kwargs + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Argument form + args_kwargs = args_kwargs_to_arguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function typed_spawn(f, args...; kwargs...) + # Merge all passed options + if length(args) >= 1 && first(args) isa Options + # N.B. Make a defensive copy in case user aliases Options struct + task_options = copy(first(args)::Options) + args = args[2:end] + else + task_options = Options() + end + + # Process the args and kwargs into Tuple of TypedArgument form + args_kwargs = args_kwargs_to_typedarguments(f, args, kwargs) + + return _spawn(args_kwargs, task_options) +end +function _spawn(args_kwargs, task_options) # Get all scoped options and determine which propagate beyond this task scoped_options = get_options()::NamedTuple if haskey(scoped_options, :propagates) @@ -539,20 +572,9 @@ function spawn(f, args...; kwargs...) end append!(propagates, keys(scoped_options)::NTuple{N,Symbol} where N) - # Merge all passed options - if length(args) >= 1 && first(args) isa Options - # N.B. Make a defensive copy in case user aliases Options struct - task_options = copy(first(args)::Options) - args = args[2:end] - else - task_options = Options() - end # N.B. 
Merges into task_options options_merge!(task_options, scoped_options; override=false) - # Process the args and kwargs into Pair form - args_kwargs = args_kwargs_to_arguments(f, args, kwargs) - # Get task queue, and don't let it propagate task_queue = get(scoped_options, :task_queue, DefaultTaskQueue())::AbstractTaskQueue filter!(prop -> prop != :task_queue, propagates) @@ -568,7 +590,7 @@ function spawn(f, args...; kwargs...) task = eager_spawn(spec) # Enqueue the task into the task queue - enqueue!(task_queue, spec=>task) + enqueue!(task_queue, DTaskPair(spec, task)) return task end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 61503040..873e47e7 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -35,3 +35,28 @@ macro dagdebug(thunk, category, msg, args...) end end) end + +# FIXME: Calculate fast-growth based on clock time, not iteration +const OPCOUNTER_CATEGORIES = Symbol[] +const OPCOUNTER_FAST_GROWTH_THRESHOLD = Ref(10_000_000) +struct OpCounter + value::Threads.Atomic{Int} +end +OpCounter() = OpCounter(Threads.Atomic{Int}(0)) +macro opcounter(category, count=1) + cat_sym = category.value + @gensym old + opcounter_sym = Symbol(:OPCOUNTER_, cat_sym) + if !isdefined(__module__, opcounter_sym) + __module__.eval(:(#=const=# $opcounter_sym = OpCounter())) + end + esc(quote + if $(QuoteNode(cat_sym)) in $OPCOUNTER_CATEGORIES + $old = Threads.atomic_add!($opcounter_sym.value, Int($count)) + if $old > 1 && (mod1($old, $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) == 1 || $count > $OPCOUNTER_FAST_GROWTH_THRESHOLD[]) + println("Fast-growing counter: $($(QuoteNode(cat_sym))) = $($old)") + end + end + end) +end +opcounter(mod::Module, category::Symbol) = getfield(mod, Symbol(:OPCOUNTER_, category)).value[] \ No newline at end of file diff --git a/src/utils/interval_tree.jl b/src/utils/interval_tree.jl new file mode 100644 index 00000000..e67f66b2 --- /dev/null +++ b/src/utils/interval_tree.jl @@ -0,0 +1,363 @@ +mutable struct IntervalNode{M,E} + 
span::M + max_end::E # Maximum end value in this subtree + left::Union{IntervalNode{M,E}, Nothing} + right::Union{IntervalNode{M,E}, Nothing} + + IntervalNode(span::M) where M <: MemorySpan = new{M,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocalMemorySpan) = new{LocalMemorySpan,UInt64}(span, span_end(span), nothing, nothing) + IntervalNode(span::ManyMemorySpan{N}) where N = new{ManyMemorySpan{N},ManyPair{N}}(span, span_end(span), nothing, nothing) + IntervalNode(span::LocatorMemorySpan{T}) where T = new{LocatorMemorySpan{T},UInt64}(span, span_end(span), nothing, nothing) +end + +mutable struct IntervalTree{M,E} + root::Union{IntervalNode{M,E}, Nothing} + + IntervalTree{M}() where M<:MemorySpan = new{M,UInt64}(nothing) + IntervalTree{LocalMemorySpan}() = new{LocalMemorySpan,UInt64}(nothing) + IntervalTree{ManyMemorySpan{N}}() where N = new{ManyMemorySpan{N},ManyPair{N}}(nothing) + IntervalTree{LocatorMemorySpan{T}}() where T = new{LocatorMemorySpan{T},UInt64}(nothing) +end + +# Construct interval tree from unsorted set of spans +function IntervalTree{M}(spans) where M + tree = IntervalTree{M}() + for span in spans + insert!(tree, span) + end + return tree +end +IntervalTree(spans::Vector{M}) where M = IntervalTree{M}(spans) + +function Base.show(io::IO, tree::IntervalTree) + println(io, "$(typeof(tree)) (with $(length(tree)) spans):") + for (i, span) in enumerate(tree) + println(io, " $i: [$(span_start(span)), $(span_end(span))) (len=$(span_len(span)))") + end +end + +function Base.collect(tree::IntervalTree{M}) where M + result = M[] + for span in tree + push!(result, span) + end + return result +end + +function Base.iterate(tree::IntervalTree{M}) where M + state = Vector{M}() + if tree.root === nothing + return nothing + end + return iterate(tree.root) +end +function Base.iterate(tree::IntervalTree, state) + return iterate(tree.root, state) +end +function Base.iterate(root::IntervalNode{M,E}) where {M,E} + state = 
Vector{IntervalNode{M,E}}() + push!(state, root) + return iterate(root, state) +end +function Base.iterate(root::IntervalNode, state) + if isempty(state) + return nothing + end + current = popfirst!(state) + if current.right !== nothing + pushfirst!(state, current.right) + end + if current.left !== nothing + pushfirst!(state, current.left) + end + return current.span, state +end + +function Base.length(tree::IntervalTree) + result = 0 + for _ in tree + result += 1 + end + return result +end + +# Update max_end value for a node based on its children +function update_max_end!(node::IntervalNode) + max_end = span_end(node.span) + if node.left !== nothing + max_end = max(max_end, node.left.max_end) + end + if node.right !== nothing + max_end = max(max_end, node.right.max_end) + end + node.max_end = max_end +end + +# Insert a span into the interval tree +function Base.insert!(tree::IntervalTree{M,E}, span::M) where {M,E} + if !isempty(span) + if tree.root === nothing + tree.root = IntervalNode(span) + update_max_end!(tree.root) + return span + end + #tree.root = insert_node!(tree.root, span) + to_update = Vector{IntervalNode{M,E}}() + prev_node = tree.root + cur_node = tree.root + while cur_node !== nothing + if span_start(span) <= span_start(cur_node.span) + cur_node = cur_node.left + else + cur_node = cur_node.right + end + if cur_node !== nothing + prev_node = cur_node + push!(to_update, cur_node) + end + end + if prev_node.left === nothing + prev_node.left = IntervalNode(span) + else + prev_node.right = IntervalNode(span) + end + for node_idx in eachindex(to_update) + node = to_update[node_idx] + update_max_end!(node) + end + end + return span +end + +function insert_node!(::Nothing, span::M) where M + return IntervalNode(span) +end +function insert_node!(node::IntervalNode{M,E}, span::M) where {M,E} + if span_start(span) <= span_start(node.span) + node.left = insert_node!(node.left, span) + else + node.right = insert_node!(node.right, span) + end + + 
update_max_end!(node) + return node +end + +# Remove a specific span from the tree (split as needed) +function Base.delete!(tree::IntervalTree{M}, span::M) where M + if !isempty(span) + tree.root = delete_node!(tree.root, span) + end + return span +end + +function delete_node!(::Nothing, span::M) where M + return nothing +end +function delete_node!(node::IntervalNode{M,E}, span::M) where {M,E} + # Check for exact match first + if span_start(node.span) == span_start(span) && span_len(node.span) == span_len(span) + # Exact match, remove the node + if node.left === nothing && node.right === nothing + return nothing + elseif node.left === nothing + return node.right + elseif node.right === nothing + return node.left + else + # Node has two children - replace with inorder successor + successor = find_min(node.right) + node.span = successor.span + node.right = delete_node!(node.right, successor.span) + end + # Check for overlap + elseif spans_overlap(node.span, span) + # Handle overlapping spans by removing current node and adding remainders + original_span = node.span + + # Remove the current node first (same logic as exact match) + if node.left === nothing && node.right === nothing + # Leaf node - remove it and create a new subtree with remainders + remaining_node = nothing + elseif node.left === nothing + remaining_node = node.right + elseif node.right === nothing + remaining_node = node.left + else + # Node has two children - replace with inorder successor + successor = find_min(node.right) + node.span = successor.span + node.right = delete_node!(node.right, successor.span) + remaining_node = node + end + + # Calculate and insert the remaining portions + original_start = span_start(original_span) + original_end = span_end(original_span) + del_start = span_start(span) + del_end = span_end(span) + + # Left portion: exists if original starts before deleted span + if original_start < del_start + left_end = min(original_end, del_start) + if left_end > original_start + 
left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + remaining_node = insert_node!(remaining_node, left_span) + end + end + end + + # Right portion: exists if original extends beyond deleted span + if original_end > del_end + right_start = max(original_start, del_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + remaining_node = insert_node!(remaining_node, right_span) + end + end + end + + return remaining_node + elseif span_start(span) <= span_start(node.span) + node.left = delete_node!(node.left, span) + else + node.right = delete_node!(node.right, span) + end + + if node !== nothing + update_max_end!(node) + end + return node +end + +function find_min(node::IntervalNode) + while node.left !== nothing + node = node.left + end + return node +end + +# Find all spans that overlap with the given query span +function find_overlapping(tree::IntervalTree{M}, query::M; exact::Bool=true) where M + result = M[] + find_overlapping!(tree.root, query, result; exact) + return result +end +function find_overlapping!(tree::IntervalTree{M}, query::M, result::Vector{M}; exact::Bool=true) where M + find_overlapping!(tree.root, query, result; exact) + return result +end + +function find_overlapping!(::Nothing, query::M, result::Vector{M}; exact::Bool=true) where M + return +end +function find_overlapping!(node::IntervalNode{M,E}, query::M, result::Vector{M}; exact::Bool=true) where {M,E} + # Check if current node overlaps with query + if spans_overlap(node.span, query) + if exact + # Get the overlapping portion of the span + overlap_start = max(span_start(node.span), span_start(query)) + overlap_end = min(span_end(node.span), span_end(query)) + overlap = M(overlap_start, overlap_end - overlap_start) + push!(result, overlap) + else + push!(result, node.span) + end + end + + # Recursively search left subtree if it might contain overlapping intervals + if node.left !== nothing && 
node.left.max_end > span_start(query) + find_overlapping!(node.left, query, result; exact) + end + + # Recursively search right subtree if query extends beyond current node's start + if node.right !== nothing && span_end(query) > span_start(node.span) + find_overlapping!(node.right, query, result; exact) + end +end + +# ============================================================================ +# MAIN SUBTRACTION ALGORITHM +# ============================================================================ + +""" + subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + +Subtract all spans in subtrahend_spans from the minuend_tree in-place. +The minuend_tree is modified to contain only the portions that remain after subtraction. + +Time Complexity: O(M log N + M*K) where M = |subtrahend_spans|, N = |minuend nodes|, + K = average overlaps per subtrahend span +Space Complexity: O(1) additional space (modifies tree in-place) + +If `diff` is provided, add the overlapping spans to `diff`. +""" +function subtract_spans!(minuend_tree::IntervalTree{M}, subtrahend_spans::Vector{M}, diff=nothing) where M + for sub_span in subtrahend_spans + subtract_single_span!(minuend_tree, sub_span, diff) + end +end + +""" + subtract_single_span!(tree::IntervalTree, sub_span::MemorySpan, diff=nothing) + +Subtract a single span from the interval tree. This function: +1. Finds all overlapping spans in the tree +2. Removes each overlapping span +3. Adds back the non-overlapping portions (left and/or right remnants) +4. 
If diff is provided, add the overlapping span to diff +""" +function subtract_single_span!(tree::IntervalTree{M}, sub_span::M, diff=nothing) where M + # Find all spans that overlap with the subtrahend + overlapping_spans = find_overlapping(tree, sub_span) + + # Process each overlapping span + for overlap_span in overlapping_spans + # Remove the overlapping span from the tree + delete!(tree, overlap_span) + + # Calculate and add back the portions that should remain + add_remaining_portions!(tree, overlap_span, sub_span) + + if diff !== nothing && !isempty(overlap_span) + push!(diff, overlap_span) + end + end +end + +""" + add_remaining_portions!(tree::IntervalTree, original::MemorySpan, subtracted::MemorySpan) + +After removing an overlapping span, add back the portions that don't overlap with the subtracted span. +There can be up to two remaining portions: left and right of the subtracted region. +""" +function add_remaining_portions!(tree::IntervalTree{M}, original::M, subtracted::M) where M + original_start = span_start(original) + original_end = span_end(original) + sub_start = span_start(subtracted) + sub_end = span_end(subtracted) + + # Left portion: exists if original starts before subtracted + if original_start < sub_start + left_end = min(original_end, sub_start) + if left_end > original_start + left_span = M(original_start, left_end - original_start) + if !isempty(left_span) + insert!(tree, left_span) + end + end + end + + # Right portion: exists if original extends beyond subtracted + if original_end > sub_end + right_start = max(original_start, sub_end) + if original_end > right_start + right_span = M(right_start, original_end - right_start) + if !isempty(right_span) + insert!(tree, right_span) + end + end + end +end \ No newline at end of file diff --git a/src/utils/memory-span.jl b/src/utils/memory-span.jl new file mode 100644 index 00000000..91f291cb --- /dev/null +++ b/src/utils/memory-span.jl @@ -0,0 +1,98 @@ +### Remote pointer type + +struct 
RemotePtr{T,S<:MemorySpace} <: Ref{T} + addr::UInt + space::S +end +RemotePtr{T}(addr::UInt, space::S) where {T,S} = RemotePtr{T,S}(addr, space) +RemotePtr{T}(ptr::Ptr{V}, space::S) where {T,V,S} = RemotePtr{T,S}(UInt(ptr), space) +RemotePtr{T}(ptr::Ptr{V}) where {T,V} = RemotePtr{T}(UInt(ptr), CPURAMMemorySpace(myid())) +# FIXME: Don't hardcode CPURAMMemorySpace +RemotePtr(addr::UInt) = RemotePtr{Cvoid}(addr, CPURAMMemorySpace(myid())) +Base.convert(::Type{RemotePtr}, x::Ptr{T}) where T = + RemotePtr(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{<:RemotePtr{V}}, x::Ptr{T}) where {V,T} = + RemotePtr{V}(UInt(x), CPURAMMemorySpace(myid())) +Base.convert(::Type{UInt}, ptr::RemotePtr) = ptr.addr +Base.:+(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr + offset, ptr.space) +Base.:-(ptr::RemotePtr{T}, offset::Integer) where T = RemotePtr{T}(ptr.addr - offset, ptr.space) +function Base.isless(ptr1::RemotePtr, ptr2::RemotePtr) + @assert ptr1.space == ptr2.space + return ptr1.addr < ptr2.addr +end + +### Generic memory spans + +struct MemorySpan{S} + ptr::RemotePtr{Cvoid,S} + len::UInt +end +MemorySpan(ptr::RemotePtr{Cvoid,S}, len::Integer) where S = + MemorySpan{S}(ptr, UInt(len)) +MemorySpan{S}(addr::UInt, len::Integer) where S = + MemorySpan{S}(RemotePtr{Cvoid,S}(addr), UInt(len)) +Base.isless(a::MemorySpan, b::MemorySpan) = a.ptr < b.ptr +Base.isempty(x::MemorySpan) = x.len == 0 +span_start(span::MemorySpan) = span.ptr.addr +span_len(span::MemorySpan) = span.len +span_end(span::MemorySpan) = span.ptr.addr + span.len +spans_overlap(span1::MemorySpan, span2::MemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +### More space-efficient memory spans + +struct LocalMemorySpan + ptr::UInt + len::UInt +end +LocalMemorySpan(span::MemorySpan) = LocalMemorySpan(span.ptr.addr, span.len) +Base.isempty(x::LocalMemorySpan) = x.len == 0 +span_start(span::LocalMemorySpan) = span.ptr +span_len(span::LocalMemorySpan) 
= span.len +span_end(span::LocalMemorySpan) = span.ptr + span.len +spans_overlap(span1::LocalMemorySpan, span2::LocalMemorySpan) = + span_start(span1) < span_end(span2) && span_start(span2) < span_end(span1) + +# FIXME: Store the length separately, since it's shared by all spans +struct ManyMemorySpan{N} + spans::NTuple{N,LocalMemorySpan} +end +Base.isempty(x::ManyMemorySpan) = all(isempty, x.spans) +span_start(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_start(span.spans[i]), N)) +span_len(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_len(span.spans[i]), N)) +span_end(span::ManyMemorySpan{N}) where N = ManyPair(ntuple(i -> span_end(span.spans[i]), N)) +spans_overlap(span1::ManyMemorySpan{N}, span2::ManyMemorySpan{N}) where N = + # N.B. The spans are assumed to be the same length and relative offset + spans_overlap(span1.spans[1], span2.spans[1]) + +struct ManyPair{N} <: Unsigned + pairs::NTuple{N,UInt} +end +Base.promote_rule(::Type{ManyPair}, ::Type{T}) where {T<:Integer} = ManyPair +Base.convert(::Type{ManyPair{N}}, x::T) where {T<:Integer,N} = ManyPair(ntuple(i -> x, N)) +Base.convert(::Type{ManyPair}, x::ManyPair) = x +Base.:+(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] + y.pairs[i], N)) +Base.:-(x::ManyPair{N}, y::ManyPair{N}) where N = ManyPair(ntuple(i -> x.pairs[i] - y.pairs[i], N)) +Base.:-(x::ManyPair) = error("Can't negate a ManyPair") +Base.:(==)(x::ManyPair, y::ManyPair) = x.pairs == y.pairs +Base.isless(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.:(<)(x::ManyPair, y::ManyPair) = x.pairs[1] < y.pairs[1] +Base.string(x::ManyPair) = "ManyPair($(x.pairs))" + +ManyMemorySpan{N}(start::ManyPair{N}, len::ManyPair{N}) where N = + ManyMemorySpan{N}(ntuple(i -> LocalMemorySpan(start.pairs[i], len.pairs[i]), N)) + +### Memory spans with ownership info + +struct LocatorMemorySpan{T} + span::LocalMemorySpan + owner::T +end +LocatorMemorySpan{T}(start::UInt64, len::UInt64) where T = # For 
interval tree + LocatorMemorySpan{T}(LocalMemorySpan(start, len), 0) +Base.isempty(x::LocatorMemorySpan) = span_len(x.span) == 0 +span_start(x::LocatorMemorySpan) = span_start(x.span) +span_end(x::LocatorMemorySpan) = span_end(x.span) +span_len(x::LocatorMemorySpan) = span_len(x.span) +spans_overlap(span1::LocatorMemorySpan{T}, span2::LocatorMemorySpan{T}) where T = + spans_overlap(span1.span, span2.span) \ No newline at end of file