@@ -38,14 +38,14 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
3838
3939Flux. trainable (l:: GCNConv ) = (l. weight, l. bias)
4040
41- function (l:: GCNConv )(fg:: FeaturedGraph , x:: AbstractMatrix )
41+ function (l:: GCNConv )(fg:: ConcreteFeaturedGraph , x:: AbstractMatrix )
4242 Ã = Zygote. ignore () do
4343 GraphSignals. normalized_adjacency_matrix (fg, eltype (x); selfloop= true )
4444 end
4545 l. σ .(l. weight * x * Ã .+ l. bias)
4646end
4747
48- (l:: GCNConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
48+ (l:: GCNConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
4949
5050function Base. show (io:: IO , l:: GCNConv )
5151 out, in = size (l. weight)
@@ -91,7 +91,7 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
9191
9292Flux. trainable (l:: ChebConv ) = (l. weight, l. bias)
9393
94- function (c:: ChebConv )(fg:: FeaturedGraph , X:: AbstractMatrix{T} ) where T
94+ function (c:: ChebConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix{T} ) where T
9595 GraphSignals. check_num_nodes (fg, X)
9696 @assert size (X, 1 ) == size (c. weight, 2 ) " Input feature size must match input channel size."
9797
@@ -110,7 +110,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
110110 return Y .+ c. bias
111111end
112112
113- (l:: ChebConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
113+ (l:: ChebConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
114114
115115function Base. show (io:: IO , l:: ChebConv )
116116 out, in, k = size (l. weight)
@@ -165,14 +165,14 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
165165
166166update (gc:: GraphConv , m:: AbstractVector , x:: AbstractVector ) = gc. σ .(gc. weight1* x .+ m .+ gc. bias)
167167
168- function (gc:: GraphConv )(fg:: FeaturedGraph , x:: AbstractMatrix )
169- GraphSignals. check_num_nodes (fg, x)
170- _, x, _ = propagate (gc, graph (fg) , edge_feature (fg), x, global_feature (fg), + )
168+ function (gc:: GraphConv )(fg:: ConcreteFeaturedGraph , x:: AbstractMatrix )
169+ # GraphSignals.check_num_nodes(fg, x)
170+ _, x, _ = propagate (gc, fg , edge_feature (fg), x, global_feature (fg), + )
171171 x
172172end
173173
174- (l:: GraphConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
175- # (l::GraphConv)(fg::FeaturedGraph ) = propagate(l, fg, +) # edge number check break this
174+ (l:: GraphConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
175+ # (l::GraphConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, +) # edge number check break this
176176
177177function Base. show (io:: IO , l:: GraphConv )
178178 in_channel = size (l. weight1, ndims (l. weight1))
@@ -244,18 +244,22 @@ end
244244
245245# After some reshaping due to the multihead, we get the α from each message,
246246# then get the softmax over every α, and eventually multiply the message by α
247- function apply_batch_message (gat:: GATConv , i, js, X:: AbstractMatrix )
248- e_ij = mapreduce (j -> GeometricFlux. message (gat, _view (X, i), _view (X, j)), hcat, js)
249- n = size (e_ij, 1 )
250- αs = Flux. softmax (reshape (view (e_ij, 1 , :), gat. heads, :), dims= 2 )
251- msgs = view (e_ij, 2 : n, :) .* reshape (αs, 1 , :)
252- reshape (msgs, (n- 1 )* gat. heads, :)
253- end
254-
255- function update_batch_edge (gat:: GATConv , sg:: SparseGraph , E:: AbstractMatrix , X:: AbstractMatrix , u)
256- @assert Zygote. ignore (() -> check_self_loops (sg)) " a vertex must have self loop (receive a message from itself)."
257- ys = map (i -> apply_batch_message (gat, i, GraphSignals. cpu_neighbors (sg, i), X), 1 : nv (sg))
258- return hcat (ys... )
247+ function graph_attention (gat:: GATConv , i, js, X:: AbstractMatrix )
248+ e_ij = map (j -> GeometricFlux. message (gat, _view (X, i), _view (X, j)), js)
249+ E = hcat_by_sum (e_ij)
250+ n = size (E, 1 )
251+ αs = Flux. softmax (reshape (view (E, 1 , :), gat. heads, :), dims= 2 )
252+ msgs = view (E, 2 : n, :) .* reshape (αs, 1 , :)
253+ return reshape (msgs, (n- 1 )* gat. heads, :)
254+ end
255+
256+ function update_batch_edge (gat:: GATConv , fg:: AbstractFeaturedGraph , E:: AbstractMatrix , X:: AbstractMatrix , u)
257+ @assert Zygote. ignore (() -> check_self_loops (graph (fg))) " a vertex must have self loop (receive a message from itself)."
258+ nodes = Zygote. ignore (()-> vertices (fg))
259+ nbr = i-> cpu (GraphSignals. neighbors (graph (fg), i))
260+ ms = map (i -> graph_attention (gat, i, Zygote. ignore (()-> nbr (i)), X), nodes)
261+ M = hcat_by_sum (ms)
262+ return M
259263end
260264
261265function check_self_loops (sg:: SparseGraph )
@@ -267,7 +271,7 @@ function check_self_loops(sg::SparseGraph)
267271 return true
268272end
269273
270- function update_batch_vertex (gat:: GATConv , M:: AbstractMatrix , X:: AbstractMatrix , u)
274+ function update_batch_vertex (gat:: GATConv , :: AbstractFeaturedGraph , M:: AbstractMatrix , X:: AbstractMatrix , u)
271275 M = M .+ gat. bias
272276 if ! gat. concat
273277 N = size (M, 2 )
@@ -276,14 +280,14 @@ function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix,
276280 return M
277281end
278282
279- function (gat:: GATConv )(fg:: FeaturedGraph , X:: AbstractMatrix )
283+ function (gat:: GATConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
280284 GraphSignals. check_num_nodes (fg, X)
281- _, X, _ = propagate (gat, graph (fg) , edge_feature (fg), X, global_feature (fg), + )
285+ _, X, _ = propagate (gat, fg , edge_feature (fg), X, global_feature (fg), + )
282286 return X
283287end
284288
285- (l:: GATConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
286- # (l::GATConv)(fg::FeaturedGraph ) = propagate(l, fg, +) # edge number check break this
289+ (l:: GATConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
290+ # (l::GATConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, +) # edge number check break this
287291
288292function Base. show (io:: IO , l:: GATConv )
289293 in_channel = size (l. weight, ndims (l. weight))
@@ -335,7 +339,7 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
335339update (ggc:: GatedGraphConv , m:: AbstractVector , x) = m
336340
337341
338- function (ggc:: GatedGraphConv )(fg:: FeaturedGraph , H:: AbstractMatrix{S} ) where {T<: AbstractVector ,S<: Real }
342+ function (ggc:: GatedGraphConv )(fg:: ConcreteFeaturedGraph , H:: AbstractMatrix{S} ) where {T<: AbstractVector ,S<: Real }
339343 GraphSignals. check_num_nodes (fg, H)
340344 m, n = size (H)
341345 @assert (m <= ggc. out_ch) " number of input features must less or equals to output features."
@@ -347,14 +351,14 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
347351 end
348352 for i = 1 : ggc. num_layers
349353 M = view (ggc. weight, :, :, i) * H
350- _, M = propagate (ggc, graph (fg) , edge_feature (fg), M, global_feature (fg), + )
354+ _, M = propagate (ggc, fg , edge_feature (fg), M, global_feature (fg), + )
351355 H, _ = ggc. gru (H, M)
352356 end
353357 H
354358end
355359
356- (l:: GatedGraphConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
357- # (l::GatedGraphConv)(fg::FeaturedGraph ) = propagate(l, fg, +) # edge number check break this
360+ (l:: GatedGraphConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
361+ # (l::GatedGraphConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, +) # edge number check break this
358362
359363
360364function Base. show (io:: IO , l:: GatedGraphConv )
@@ -392,14 +396,14 @@ Flux.trainable(l::EdgeConv) = (l.nn,)
392396message (ec:: EdgeConv , x_i:: AbstractVector , x_j:: AbstractVector , e_ij) = ec. nn (vcat (x_i, x_j .- x_i))
393397update (ec:: EdgeConv , m:: AbstractVector , x) = m
394398
395- function (ec:: EdgeConv )(fg:: FeaturedGraph , X:: AbstractMatrix )
399+ function (ec:: EdgeConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
396400 GraphSignals. check_num_nodes (fg, X)
397- _, X, _ = propagate (ec, graph (fg) , edge_feature (fg), X, global_feature (fg), ec. aggr)
401+ _, X, _ = propagate (ec, fg , edge_feature (fg), X, global_feature (fg), ec. aggr)
398402 X
399403end
400404
401- (l:: EdgeConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
402- # (l::EdgeConv)(fg::FeaturedGraph ) = propagate(l, fg, l.aggr) # edge number check break this
405+ (l:: EdgeConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
406+ # (l::EdgeConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, l.aggr) # edge number check break this
403407
404408function Base. show (io:: IO , l:: EdgeConv )
405409 print (io, " EdgeConv(" , l. nn)
@@ -443,15 +447,15 @@ Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
443447message (g:: GINConv , x_i:: AbstractVector , x_j:: AbstractVector ) = x_j
444448update (g:: GINConv , m:: AbstractVector , x) = g. nn ((1 + g. eps) * x + m)
445449
446- function (g:: GINConv )(fg:: FeaturedGraph , X:: AbstractMatrix )
450+ function (g:: GINConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
447451 gf = graph (fg)
448452 GraphSignals. check_num_nodes (gf, X)
449- _, X, _ = propagate (g, graph (fg) , edge_feature (fg), X, global_feature (fg), + )
453+ _, X, _ = propagate (g, fg , edge_feature (fg), X, global_feature (fg), + )
450454 X
451455end
452456
453- (l:: GINConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
454- # (l::GINConv)(fg::FeaturedGraph ) = propagate(l, fg, +) # edge number check break this
457+ (l:: GINConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
458+ # (l::GINConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, +) # edge number check break this
455459
456460
457461"""
@@ -512,17 +516,17 @@ message(c::CGConv,
512516end
513517update (c:: CGConv , m:: AbstractVector , x) = x + m
514518
515- function (c:: CGConv )(fg:: FeaturedGraph , X:: AbstractMatrix , E:: AbstractMatrix )
519+ function (c:: CGConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix , E:: AbstractMatrix )
516520 GraphSignals. check_num_nodes (fg, X)
517521 GraphSignals. check_num_edges (fg, E)
518- _, Y, _ = propagate (c, graph (fg) , E, X, global_feature (fg), + )
522+ _, Y, _ = propagate (c, fg , E, X, global_feature (fg), + )
519523 Y
520524end
521525
522- (l:: CGConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg,
526+ (l:: CGConv )(fg:: AbstractFeaturedGraph ) = FeaturedGraph (fg,
523527 nf= l (fg, node_feature (fg), edge_feature (fg)),
524528 ef= edge_feature (fg))
525- # (l::CGConv)(fg::FeaturedGraph ) = propagate(l, fg, +) # edge number check break this
529+ # (l::CGConv)(fg::AbstractFeaturedGraph ) = propagate(l, fg, +) # edge number check break this
526530
527531(l:: CGConv )(X:: AbstractMatrix , E:: AbstractMatrix ) = l (l. fg, X, E)
528532
0 commit comments