11_view (:: Nothing , idx) = nothing
22_view (A:: Fill{T,2,Axes} , idx) where {T,Axes} = fill (A. value, A. axes[1 ], length (idx))
3-
4- function _view (A:: SubArray{T,2,S} , idx) where {T,S<: Fill }
5- p = parent (A)
6- return Fill (p. value, p. axes[1 ]. stop, length (idx))
7- end
8-
93_view (A:: AbstractMatrix , idx) = view (A, :, idx)
104
11- function _view (A:: SubArray{T,2,S} , idxs) where {T,S<: AbstractMatrix }
12- view_idx = A. indices[2 ]
13- if view_idx == idxs
14- return A
15- else
16- idxs = findall (x -> x in idxs, view_idx)
17- return view (A, :, idxs)
18- end
19- end
20-
215aggregate (aggr:: typeof (+ ), X) = vec (sum (X, dims= 2 ))
226aggregate (aggr:: typeof (- ), X) = - vec (sum (X, dims= 2 ))
237aggregate (aggr:: typeof (* ), X) = vec (prod (X, dims= 2 ))
@@ -32,10 +16,16 @@ abstract type GraphNet <: AbstractGraphLayer end
3216@inline update_vertex (gn:: GraphNet , ē, vi, u) = vi
3317@inline update_global (gn:: GraphNet , ē, v̄, u) = u
3418
19+ function _get_indices (fg:: AbstractFeaturedGraph )
20+ es = cpu (GraphSignals. incident_edges (fg))
21+ xs = cpu (GraphSignals. repeat_nodes (fg))
22+ nbrs = cpu (GraphSignals. neighbors (fg))
23+ sorted_idx = sort! (collect (zip (es, xs, nbrs)), by= x-> x[1 ])
24+ return collect .(collect (zip (sorted_idx... )))
25+ end
26+
3527@inline function update_batch_edge (gn:: GraphNet , fg:: AbstractFeaturedGraph , E, V, u)
36- es = Zygote. ignore (()-> cpu (GraphSignals. incident_edges (fg)))
37- xs = Zygote. ignore (()-> cpu (GraphSignals. repeat_nodes (fg)))
38- nbrs = Zygote. ignore (()-> cpu (GraphSignals. neighbors (fg)))
28+ es, xs, nbrs = Zygote. ignore (()-> _get_indices (fg))
3929 ms = map ((e,i,j)-> update_edge (gn, _view (E, e), _view (V, i), _view (V, j), u), es, xs, nbrs)
4030 M = hcat_by_sum (ms)
4131 return M
5040
5141@inline function aggregate_neighbors (gn:: GraphNet , fg:: AbstractFeaturedGraph , aggr, E)
5242 N = nv (parent (fg))
53- xs = Zygote. ignore (()-> cpu (GraphSignals . repeat_nodes (fg) ))
43+ es, xs, nbrs = Zygote. ignore (()-> _get_indices (fg ))
5444 Ē = NNlib. scatter (aggr, E, xs; dstsize= (size (E, 1 ), N))
5545 return Ē
5646end
@@ -67,6 +57,14 @@ function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, naggr=nothing, eaggr
6757 FeaturedGraph (fg, nf= V, ef= E, gf= u)
6858end
6959
60+ """
61+ - `update_batch_edge`: (E_in_dim, E) -> (E_out_dim, E)
62+ - `aggregate_neighbors`: (E_out_dim, E) -> (E_out_dim, V)
63+ - `update_batch_vertex`: (V_in_dim, V) -> (V_out_dim, V)
64+ - `aggregate_edges`: (E_out_dim, E) -> (E_out_dim,)
65+ - `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim,)
66+ - `update_global`: (dim,) -> (dim,)
67+ """
7068function propagate (gn:: GraphNet , fg:: AbstractFeaturedGraph , E:: AbstractArray , V:: AbstractArray , u:: AbstractArray ,
7169 naggr= nothing , eaggr= nothing , vaggr= nothing )
7270 E = update_batch_edge (gn, fg, E, V, u)
@@ -75,5 +73,5 @@ function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, E::AbstractArray, V:
7573 ē = aggregate_edges (gn, eaggr, E)
7674 v̄ = aggregate_vertices (gn, vaggr, V)
7775 u = update_global (gn, ē, v̄, u)
78- return parent (E), parent (V) , u
76+ return E, V , u
7977end
0 commit comments