Skip to content

Commit d67e025

Browse files
committed
Add batched add_edges interface
1 parent c5196a3 commit d67e025

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

lib/DaggerGraphs/src/adjlist.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,22 @@ AdjList{T,D}() where {T,D} = AdjList{T,D}(SimpleAdjListStorage{T,D}())
124124
AdjList() = AdjList{Int,true}()
125125
Base.copy(adj::AdjList{T,D,A}) where {T,D,A} = AdjList{T,D,A}(copy(adj.data))
126126
Base.in(adj::AdjList{T,D}, edge) where {T,D} = edge in adj.data
127-
function Graphs.add_edge!(adj::AdjList{T,D}, edge) where {T,D}
127+
Graphs.add_edge!(adj::AdjList{T}, src::Integer, dst::Integer) where T =
128+
add_edge!(adj, Edge{T}(src, dst))
129+
function Graphs.add_edge!(adj::AdjList, edge)
128130
if edge in adj.data
129131
return false
130132
end
131133
push!(adj.data, edge)
132134
return true
133135
end
136+
function add_edges!(g::AdjList, edges)
137+
for edge in edges
138+
src, dst = Tuple(edge)
139+
add_edge!(g, src, dst) || return false
140+
end
141+
return true
142+
end
134143
Graphs.edges(adj::AdjList) = copy(adj.data)
135144
function Graphs.inneighbors(adj::AdjList, v::Integer)
136145
neighbors = Int[]

lib/DaggerGraphs/src/dgraph.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ function Base.show(io::IO, g::DGraph{T,D}) where {T,D}
141141
print(io, "{$(nv(g)), $(ne(g))} $(D ? "" : "un")directed Dagger $T graph$(isfrozen(g) ? " (frozen)" : "")")
142142
end
143143

144+
nparts(g::DGraph) = with_state(g, nparts)
145+
nparts(g::DGraphState) = length(g.parts)
144146
Base.eltype(::DGraph{T}) where T = T
145147
Graphs.edgetype(::DGraph{T}) where T = Tuple{T,T}
146148
Graphs.nv(g::DGraph) = with_state(g, nv)::Int
@@ -280,6 +282,63 @@ function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where
280282

281283
return true
282284
end
285+
function add_edges!(g::DGraph, iter)
286+
check_not_frozen(g)
287+
return with_state(g, add_edges!, iter)
288+
end
289+
function add_edges!(g::DGraphState{T,D}, iter) where {T,D}
290+
check_not_frozen(g)
291+
292+
# Determine edge partition/background
293+
part_edges = Dict{Int,Vector{Tuple{T,T}}}(part=>Tuple{T,T}[] for part in 1:nparts(g))
294+
back_edges = Dict{Int,Vector{Tuple{T,T}}}(part=>Tuple{T,T}[] for part in 1:nparts(g))
295+
for edge in iter
296+
src, dst = Tuple(edge)
297+
298+
src_part_idx = findfirst(span->src in span, g.parts_nv)
299+
@assert src_part_idx !== nothing "Source vertex $src does not exist"
300+
301+
dst_part_idx = findfirst(span->dst in span, g.parts_nv)
302+
@assert dst_part_idx !== nothing "Destination vertex $dst does not exist"
303+
304+
if src_part_idx == dst_part_idx
305+
push!(part_edges[src_part_idx], (src, dst))
306+
else
307+
owner_part_idx = D ? src_part_idx : edge_owner(src, dst, src_part_idx, dst_part_idx)
308+
push!(back_edges[owner_part_idx], (src, dst))
309+
end
310+
end
311+
312+
# Add edges concurrently
313+
part_tasks = [exec_fast(add_edges!, g.parts[part], g.parts_nv[part].start-1, edges; fetch=false) for (part, edges) in part_edges]
314+
back_tasks = [exec_fast(add_edges!, g.bg_adjs[part], edges; fetch=false) for (part, edges) in back_edges]
315+
316+
# Validate that all edges were successfully added
317+
if !all(fetch, part_tasks) || !all(fetch, back_tasks)
318+
return false
319+
end
320+
321+
# Update edge counters
322+
for (part, edges) in part_edges
323+
g.parts_ne[part] += length(edges)
324+
end
325+
for (part, edges) in back_edges
326+
g.bg_adjs_ne_src[part] += length(edges)
327+
#= FIXME
328+
g.bg_adjs_ne[src_part_idx] += 1
329+
g.bg_adjs_ne[dst_part_idx] += 1
330+
=#
331+
end
332+
333+
return true
334+
end
335+
function add_edges!(g::Graphs.AbstractSimpleGraph, shift, edges)
336+
for edge in edges
337+
src, dst = Tuple(edge)
338+
add_edge!(g, src-shift, dst-shift) || return false
339+
end
340+
return true
341+
end
283342
edge_owner(src::Int, dst::Int, src_part_idx::Int, dst_part_idx::Int) =
284343
iseven(hash(Base.unsafe_trunc(UInt, src+dst))) ? src_part_idx : dst_part_idx
285344
Graphs.inneighbors(g::DGraph, v::Integer) = with_state(g, inneighbors, v)

0 commit comments

Comments
 (0)