@@ -87,26 +87,32 @@ function DGraph(dg::DGraph{T,D}; chunksize::Integer=0, directed::Bool=D, freeze:
8787 freeze && freeze! (g)
8888 return g
8989end
90- function with_state (g:: DGraph , f, args... )
90+ function with_state (g:: DGraph , f, args... ; kwargs ... )
9191 if g. frozen[]
9292 @assert ! any (x-> x isa ELTYPE, args)
93- return f (g. state, args... )
93+ return f (g. state, args... ; kwargs ... )
9494 else
95- return fetch (Dagger. @spawn f (g. state, args... ))
95+ return fetch (Dagger. @spawn f (g. state, args... ; kwargs ... ))
9696 end
9797end
98- function exec_fast (f, args... ; fetch :: Bool = true )
98+ function exec_fast (f, args... ; kwargs ... )
9999 # FIXME : Ensure that `EagerThunk` result is also local
100100 if any (x-> (x isa Dagger. EagerThunk && ! isready (x)) ||
101101 (x isa Dagger. Chunk && x. handle. owner != myid ()), args)
102- if fetch
103- return Base. fetch (Dagger. @spawn f (args... ))
104- else
105- return Dagger. @spawn f (args... )
106- end
102+ return Base. fetch (Dagger. @spawn f (args... ; kwargs... ))
103+ else
104+ fetched_args = ntuple (i-> args[i] isa ELTYPE ? Base. fetch (args[i]) : args[i], length (args))
105+ return f (fetched_args... ; kwargs... )
106+ end
107+ end
108+ function exec_fast_nofetch (f, args... ; kwargs... )
109+ # FIXME : Ensure that `EagerThunk` result is also local
110+ if any (x-> (x isa Dagger. EagerThunk && ! isready (x)) ||
111+ (x isa Dagger. Chunk && x. handle. owner != myid ()), args)
112+ return Dagger. @spawn f (args... ; kwargs... )
107113 else
108114 fetched_args = ntuple (i-> args[i] isa ELTYPE ? Base. fetch (args[i]) : args[i], length (args))
109- return f (fetched_args... )
115+ return f (fetched_args... ; kwargs ... )
110116 end
111117end
112118
@@ -172,7 +178,7 @@ function Graphs.has_edge(g::DGraphState{T,D}, src::Integer, dst::Integer) where
172178 else
173179 # The edge will be in an AdjList
174180 adj = g. bg_adjs[src_part_idx]
175- return exec_fast (Base . in , adj, ( src, dst) )
181+ return exec_fast (has_edge , adj, src, dst)
176182 end
177183end
178184Graphs. is_directed (:: DGraph{T,D} ) where {T,D} = D
@@ -235,12 +241,13 @@ function add_partition!(g::DGraph, sg::AbstractGraph)
235241 check_not_frozen (g)
236242 return with_state (g, add_partition!, sg)
237243end
238- function add_partition! (g:: DGraphState{T,D} , sg:: AbstractGraph ) where {T,D}
244+ function add_partition! (g:: DGraphState{T,D} , sg:: AbstractGraph ; all :: Bool = true ) where {T,D}
239245 check_not_frozen (g)
240246 shift = nv (g)
241247 part = add_partition! (g, nv (sg))
242248 part_edges = map (edge-> (src (edge)+ shift, dst (edge)+ shift), collect (edges (sg)))
243- @assert add_edges! (g, part_edges)
249+ count = add_edges! (g, part_edges; all)
250+ @assert ! all || count == length (part_edges)
244251 return part
245252end
246253function Graphs. add_edge! (g:: DGraph , src:: Integer , dst:: Integer )
@@ -292,17 +299,19 @@ function Graphs.add_edge!(g::DGraphState{T,D}, src::Integer, dst::Integer) where
292299
293300 return true
294301end
295- function add_edges! (g:: DGraph , iter)
302+ function add_edges! (g:: DGraph , iter; all :: Bool = true )
296303 check_not_frozen (g)
297- return with_state (g, add_edges!, iter)
304+ return with_state (g, add_edges!, iter; all )
298305end
299- function add_edges! (g:: DGraphState{T,D} , iter) where {T,D}
306+ function add_edges! (g:: DGraphState{T,D} , iter; all :: Bool = true ) where {T,D}
300307 check_not_frozen (g)
301308
302309 # Determine edge partition/background
303310 part_edges = Dict {Int,Vector{Tuple{T,T}}} (part=> Tuple{T,T}[] for part in 1 : nparts (g))
304311 back_edges = Dict {Int,Vector{Tuple{T,T}}} (part=> Tuple{T,T}[] for part in 1 : nparts (g))
312+ nedges = 0
305313 for edge in iter
314+ nedges += 1
306315 src, dst = Tuple (edge)
307316
308317 src_part_idx = findfirst (span-> src in span, g. parts_nv)
@@ -320,34 +329,32 @@ function add_edges!(g::DGraphState{T,D}, iter) where {T,D}
320329 end
321330
322331 # Add edges concurrently
323- part_tasks = [exec_fast (add_edges!, g. parts[part], g. parts_nv[part]. start- 1 , edges; fetch= false ) for (part, edges) in part_edges]
324- back_tasks = [exec_fast (add_edges!, g. bg_adjs[part], edges; fetch= false ) for (part, edges) in back_edges]
325-
326- # Validate that all edges were successfully added
327- if ! all (fetch, part_tasks) || ! all (fetch, back_tasks)
328- return false
329- end
332+ part_tasks = Dict (part=> exec_fast_nofetch (add_edges!, g. parts[part], g. parts_nv[part]. start- 1 , edges; all) for (part, edges) in part_edges)
333+ back_tasks = Dict (part=> exec_fast_nofetch (add_edges!, g. bg_adjs[part], edges; all) for (part, edges) in back_edges)
330334
331335 # Update edge counters
332- for (part, edges ) in part_edges
333- g. parts_ne[part] += length (edges )
336+ for (part, edge_count ) in part_tasks
337+ g. parts_ne[part] += fetch (edge_count )
334338 end
335- for (part, edges) in back_edges
336- g. bg_adjs_ne_src[part] += length (edges)
337- #= FIXME
338- g.bg_adjs_ne[src_part_idx] += 1
339- g.bg_adjs_ne[dst_part_idx] += 1
340- =#
339+ for (part, edge_count) in back_tasks
340+ g. bg_adjs_ne_src[part] += fetch (edge_count)
341+ g. bg_adjs_ne[part] = exec_fast (ne, g. bg_adjs[part])
341342 end
342343
343- return true
344+ # Validate that all edges were successfully added
345+ return sum (fetch, values (part_tasks)) + sum (fetch, values (back_tasks))
344346end
345- function add_edges! (g:: Graphs.AbstractSimpleGraph , shift, edges)
347+ function add_edges! (g:: Graphs.AbstractSimpleGraph , shift, edges; all:: Bool = true )
348+ count = 0
346349 for edge in edges
347350 src, dst = Tuple (edge)
348- add_edge! (g, src- shift, dst- shift) || return false
351+ if add_edge! (g, src- shift, dst- shift)
352+ count += 1
353+ elseif all
354+ return count
355+ end
349356 end
350- return true
357+ return count
351358end
352359edge_owner (src:: Int , dst:: Int , src_part_idx:: Int , dst_part_idx:: Int ) =
353360 iseven (hash (Base. unsafe_trunc (UInt, src+ dst))) ? src_part_idx : dst_part_idx
0 commit comments