Skip to content

Commit e63c858

Browse files
committed
Assorted add_edges fixes
1 parent 853c3b9 commit e63c858

File tree

3 files changed

+88
-58
lines changed

3 files changed

+88
-58
lines changed

lib/DaggerGraphs/src/adjlist.jl

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@ abstract type AbstractAdjListStorage{T,D} end
44
Base.IteratorSize(::Type{<:AbstractAdjListStorage}) = Base.HasLength()
55
Base.IteratorEltype(::Type{<:AbstractAdjListStorage}) = Base.HasEltype()
66
Base.eltype(::Type{<:AbstractAdjListStorage{T}}) where T = Edge{T}
7+
function add_edges!(adjlist::AbstractAdjListStorage, edges; all::Bool=true)
8+
count = 0
9+
for edge in edges
10+
if add_edge!(adjlist, edge)
11+
count += 1
12+
elseif all
13+
return count
14+
end
15+
end
16+
return count
17+
end
718

819
# Storage matching Graphs.SimpleGraph for high edge counts
920
struct SimpleAdjListStorage{T,D} <: AbstractAdjListStorage{T,D}
@@ -46,12 +57,14 @@ function Base.iterate(adjlist::SimpleAdjListStorage{T}, state::Tuple{T,T}) where
4657
dst += 1
4758
return (Edge(value), (src, dst))
4859
end
49-
function Base.push!(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
60+
function Graphs.add_edge!(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
5061
src, dst = Tuple(edge)
5162
if !D
5263
src, dst = (min(src, dst), max(src, dst))
5364
end
5465

66+
has_edge(adjlist, edge) && return false
67+
5568
# If necessary, allocate more inner vectors
5669
nv = max(src, dst)
5770
if nv > length(adjlist.fadjlist)
@@ -72,9 +85,9 @@ function Base.push!(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
7285
push!(adjlist.fadjlist[src], dst)
7386
end
7487

75-
return adjlist
88+
return true
7689
end
77-
function Base.in(edge, adjlist::SimpleAdjListStorage{T,D}) where {T,D}
90+
function Graphs.has_edge(adjlist::SimpleAdjListStorage{T,D}, edge) where {T,D}
7891
src, dst = Tuple(edge)
7992
if !D
8093
src, dst = (min(src, dst), max(src, dst))
@@ -101,11 +114,29 @@ function Base.iterate(adjlist::SparseAdjListStorage{T}, state=one(T)) where T
101114
value = adjlist.adjlist[state]
102115
return (Edge(value), state+one(T))
103116
end
104-
function Base.push!(adjlist::SparseAdjListStorage, edge)
117+
function Graphs.add_edge!(adjlist::SparseAdjListStorage, edge)
118+
if findfirst(==(Tuple(edge)), adjlist.adjlist) !== nothing
119+
return false
120+
end
105121
push!(adjlist.adjlist, Tuple(edge))
106-
return adjlist
122+
return true
123+
end
124+
function add_edges!(adjlist::SparseAdjListStorage, edges; all::Bool=true)
125+
# FIXME: Account for non-directedness
126+
edge_set = Set(map(Tuple, edges))
127+
for edge in adjlist.adjlist
128+
if edge in edge_set
129+
if all
130+
return 0
131+
else
132+
pop!(edge_set, edge)
133+
end
134+
end
135+
end
136+
append!(adjlist.adjlist, collect(edge_set))
137+
return length(edge_set)
107138
end
108-
function Base.in(edge, adjlist::SparseAdjListStorage{T,D}) where {T,D}
139+
function Graphs.has_edge(adjlist::SparseAdjListStorage{T,D}, edge) where {T,D}
109140
src, dst = Tuple(edge)
110141
if !D
111142
src, dst = (min(src, dst), max(src, dst))
@@ -120,26 +151,18 @@ struct AdjList{T,D,A<:AbstractAdjListStorage{T,D}}
120151
end
121152
AdjList{T,D}(adjlist::AbstractAdjListStorage{T,D}) where {T,D} =
122153
AdjList{T,D,typeof(adjlist)}(adjlist)
123-
AdjList{T,D}() where {T,D} = AdjList{T,D}(SimpleAdjListStorage{T,D}())
154+
# TODO: AdjList{T,D}() where {T,D} = AdjList{T,D}(SimpleAdjListStorage{T,D}())
155+
AdjList{T,D}() where {T,D} = AdjList{T,D}(SparseAdjListStorage{T,D}())
124156
AdjList() = AdjList{Int,true}()
125157
Base.copy(adj::AdjList{T,D,A}) where {T,D,A} = AdjList{T,D,A}(copy(adj.data))
126-
Base.in(adj::AdjList{T,D}, edge) where {T,D} = edge in adj.data
158+
Graphs.ne(adj::AdjList) = length(adj.data) # TODO: Use ne()
159+
Graphs.has_edge(adj::AdjList{T}, src::Integer, dst::Integer) where T =
160+
has_edge(adj.data, Edge{T}(src, dst))
161+
Graphs.has_edge(adj::AdjList{T,D}, edge) where {T,D} = has_edge(adj.data, edge)
127162
Graphs.add_edge!(adj::AdjList{T}, src::Integer, dst::Integer) where T =
128163
add_edge!(adj, Edge{T}(src, dst))
129-
function Graphs.add_edge!(adj::AdjList, edge)
130-
if edge in adj.data
131-
return false
132-
end
133-
push!(adj.data, edge)
134-
return true
135-
end
136-
function add_edges!(g::AdjList, edges)
137-
for edge in edges
138-
src, dst = Tuple(edge)
139-
add_edge!(g, src, dst) || return false
140-
end
141-
return true
142-
end
164+
Graphs.add_edge!(adj::AdjList, edge) = add_edge!(adj.data, edge)
165+
add_edges!(adj::AdjList, edges; all::Bool=true) = add_edges!(adj.data, edges; all)
143166
Graphs.edges(adj::AdjList) = copy(adj.data)
144167
function Graphs.inneighbors(adj::AdjList, v::Integer)
145168
neighbors = Int[]

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -87,26 +87,32 @@ function DGraph(dg::DGraph{T,D}; chunksize::Integer=0, directed::Bool=D, freeze:
8787
freeze && freeze!(g)
8888
return g
8989
end
90-
function with_state(g::DGraph, f, args...)
90+
function with_state(g::DGraph, f, args...; kwargs...)
9191
if g.frozen[]
9292
@assert !any(x->x isa ELTYPE, args)
93-
return f(g.state, args...)
93+
return f(g.state, args...; kwargs...)
9494
else
95-
return fetch(Dagger.@spawn f(g.state, args...))
95+
return fetch(Dagger.@spawn f(g.state, args...; kwargs...))
9696
end
9797
end
98-
function exec_fast(f, args...; fetch::Bool=true)
98+
function exec_fast(f, args...; kwargs...)
9999
# FIXME: Ensure that `EagerThunk` result is also local
100100
if any(x->(x isa Dagger.EagerThunk && !isready(x)) ||
101101
(x isa Dagger.Chunk && x.handle.owner != myid()), args)
102-
if fetch
103-
return Base.fetch(Dagger.@spawn f(args...))
104-
else
105-
return Dagger.@spawn f(args...)
106-
end
102+
return Base.fetch(Dagger.@spawn f(args...; kwargs...))
103+
else
104+
fetched_args = ntuple(i->args[i] isa ELTYPE ? Base.fetch(args[i]) : args[i], length(args))
105+
return f(fetched_args...; kwargs...)
106+
end
107+
end
108+
function exec_fast_nofetch(f, args...; kwargs...)
109+
# FIXME: Ensure that `EagerThunk` result is also local
110+
if any(x->(x isa Dagger.EagerThunk && !isready(x)) ||
111+
(x isa Dagger.Chunk && x.handle.owner != myid()), args)
112+
return Dagger.@spawn f(args...; kwargs...)
107113
else
108114
fetched_args = ntuple(i->args[i] isa ELTYPE ? Base.fetch(args[i]) : args[i], length(args))
109-
return f(fetched_args...)
115+
return f(fetched_args...; kwargs...)
110116
end
111117
end
112118

@@ -172,7 +178,7 @@ function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where
172178
else
173179
# The edge will be in an AdjList
174180
adj = g.bg_adjs[src_part_idx]
175-
return exec_fast(Base.in, adj, (src, dst))
181+
return exec_fast(has_edge, adj, src, dst)
176182
end
177183
end
178184
Graphs.is_directed(::DGraph{T,D}) where {T,D} = D
@@ -235,12 +241,13 @@ function add_partition!(g::DGraph, sg::AbstractGraph)
235241
check_not_frozen(g)
236242
return with_state(g, add_partition!, sg)
237243
end
238-
function add_partition!(g::DGraphState{T,D}, sg::AbstractGraph) where {T,D}
244+
function add_partition!(g::DGraphState{T,D}, sg::AbstractGraph; all::Bool=true) where {T,D}
239245
check_not_frozen(g)
240246
shift = nv(g)
241247
part = add_partition!(g, nv(sg))
242248
part_edges = map(edge->(src(edge)+shift, dst(edge)+shift), collect(edges(sg)))
243-
@assert add_edges!(g, part_edges)
249+
count = add_edges!(g, part_edges; all)
250+
@assert !all || count == length(part_edges)
244251
return part
245252
end
246253
function Graphs.add_edge!(g::DGraph, src::Integer, dst::Integer)
@@ -292,17 +299,19 @@ function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where
292299

293300
return true
294301
end
295-
function add_edges!(g::DGraph, iter)
302+
function add_edges!(g::DGraph, iter; all::Bool=true)
296303
check_not_frozen(g)
297-
return with_state(g, add_edges!, iter)
304+
return with_state(g, add_edges!, iter; all)
298305
end
299-
function add_edges!(g::DGraphState{T,D}, iter) where {T,D}
306+
function add_edges!(g::DGraphState{T,D}, iter; all::Bool=true) where {T,D}
300307
check_not_frozen(g)
301308

302309
# Determine edge partition/background
303310
part_edges = Dict{Int,Vector{Tuple{T,T}}}(part=>Tuple{T,T}[] for part in 1:nparts(g))
304311
back_edges = Dict{Int,Vector{Tuple{T,T}}}(part=>Tuple{T,T}[] for part in 1:nparts(g))
312+
nedges = 0
305313
for edge in iter
314+
nedges += 1
306315
src, dst = Tuple(edge)
307316

308317
src_part_idx = findfirst(span->src in span, g.parts_nv)
@@ -320,34 +329,32 @@ function add_edges!(g::DGraphState{T,D}, iter) where {T,D}
320329
end
321330

322331
# Add edges concurrently
323-
part_tasks = [exec_fast(add_edges!, g.parts[part], g.parts_nv[part].start-1, edges; fetch=false) for (part, edges) in part_edges]
324-
back_tasks = [exec_fast(add_edges!, g.bg_adjs[part], edges; fetch=false) for (part, edges) in back_edges]
325-
326-
# Validate that all edges were successfully added
327-
if !all(fetch, part_tasks) || !all(fetch, back_tasks)
328-
return false
329-
end
332+
part_tasks = Dict(part=>exec_fast_nofetch(add_edges!, g.parts[part], g.parts_nv[part].start-1, edges; all) for (part, edges) in part_edges)
333+
back_tasks = Dict(part=>exec_fast_nofetch(add_edges!, g.bg_adjs[part], edges; all) for (part, edges) in back_edges)
330334

331335
# Update edge counters
332-
for (part, edges) in part_edges
333-
g.parts_ne[part] += length(edges)
336+
for (part, edge_count) in part_tasks
337+
g.parts_ne[part] += fetch(edge_count)
334338
end
335-
for (part, edges) in back_edges
336-
g.bg_adjs_ne_src[part] += length(edges)
337-
#= FIXME
338-
g.bg_adjs_ne[src_part_idx] += 1
339-
g.bg_adjs_ne[dst_part_idx] += 1
340-
=#
339+
for (part, edge_count) in back_tasks
340+
g.bg_adjs_ne_src[part] += fetch(edge_count)
341+
g.bg_adjs_ne[part] = exec_fast(ne, g.bg_adjs[part])
341342
end
342343

343-
return true
344+
# Validate that all edges were successfully added
345+
return sum(fetch, values(part_tasks)) + sum(fetch, values(back_tasks))
344346
end
345-
function add_edges!(g::Graphs.AbstractSimpleGraph, shift, edges)
347+
function add_edges!(g::Graphs.AbstractSimpleGraph, shift, edges; all::Bool=true)
348+
count = 0
346349
for edge in edges
347350
src, dst = Tuple(edge)
348-
add_edge!(g, src-shift, dst-shift) || return false
351+
if add_edge!(g, src-shift, dst-shift)
352+
count += 1
353+
elseif all
354+
return count
355+
end
349356
end
350-
return true
357+
return count
351358
end
352359
edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) =
353360
iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx

lib/DaggerGraphs/src/io.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function load_dir(dir::String; freeze::Bool=false)
1414
edges = readdlm(file)
1515
srcs = edges[:,1]
1616
dsts = edges[:,2]
17-
add_edges!(dg, zip(srcs, dsts))
17+
@assert add_edges!(dg, zip(srcs, dsts)) == size(edges, 1)
1818
end
1919
freeze && freeze!(dg)
2020
return dg

0 commit comments

Comments
 (0)