@@ -36,6 +36,8 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
 
 @functor GCNConv
 
+Flux.trainable(l::GCNConv) = (l.weight, l.bias)
+
 function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
     Ã = Zygote.ignore() do
         GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
@@ -87,6 +89,8 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
 
 @functor ChebConv
 
+Flux.trainable(l::ChebConv) = (l.weight, l.bias)
+
 function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
     GraphSignals.check_num_nodes(fg, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
@@ -155,6 +159,8 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) =
 
 @functor GraphConv
 
+Flux.trainable(l::GraphConv) = (l.weight1, l.weight2, l.bias)
+
 message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
 
 update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
@@ -224,6 +230,8 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
 
 @functor GATConv
 
+Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
+
 # Here the α that has not been softmaxed is the first number of the output message
 function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector)
     x_i = reshape(gat.weight*x_i, :, gat.heads)
@@ -319,6 +327,8 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) =
 
 @functor GatedGraphConv
 
+Flux.trainable(l::GatedGraphConv) = (l.weight, l.gru)
+
 message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
@@ -376,6 +386,8 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...)
 
 @functor EdgeConv
 
+Flux.trainable(l::EdgeConv) = (l.nn,)
+
 message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
 update(ec::EdgeConv, m::AbstractVector, x) = m
 
@@ -423,13 +435,13 @@ function GINConv(nn, eps::Real=0f0)
     GINConv(NullGraph(), nn, eps)
 end
 
+@functor GINConv
+
 Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
 
 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
 
-@functor GINConv
-
 function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
     gf = graph(fg)
     GraphSignals.check_num_nodes(gf, X)
@@ -474,6 +486,8 @@
 
 @functor CGConv
 
+Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
+
 function CGConv(fg::G, dims::NTuple{2,Int};
                 init=glorot_uniform, bias=true, as_edge=false) where {G<:AbstractFeaturedGraph}
     node_dim, edge_dim = dims
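
The pattern this commit applies to every layer is an explicit Flux.trainable overload next to @functor: @functor makes all fields reachable for structural traversal (gpu/cpu, fmap), while trainable restricts which of them Flux.params collects for gradient updates, so non-learnable state (such as a stored graph) is moved between devices but never optimised. The GINConv hunk additionally moves @functor ahead of the trainable overload, matching the declaration order used by the other layers. Below is a minimal sketch of the idea, assuming the tuple-returning trainable interface of the Flux version used here; MyConv and its fields are hypothetical names for illustration, not part of this commit:

    using Flux

    # Hypothetical layer: `weight`/`bias` are learned, `cache` is non-trainable state.
    struct MyConv
        weight::Matrix{Float32}
        bias::Vector{Float32}
        cache::Any
    end

    Flux.@functor MyConv                            # all fields traversed for gpu/cpu/fmap
    Flux.trainable(l::MyConv) = (l.weight, l.bias)  # but only these receive updates

    layer = MyConv(randn(Float32, 4, 3), zeros(Float32, 4), nothing)
    ps = Flux.params(layer)   # collects `weight` and `bias`; `cache` is excluded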