@@ -19,7 +19,7 @@ julia> gc = GCNConv(1024=>256, relu)
1919GCNConv(1024 => 256, relu)
2020```
2121
22- See also [`WithGraph`](@ref) for training layer with fixed graph or subgraph .
22+ See also [`WithGraph`](@ref) for training layer with fixed graph.
2323"""
2424struct GCNConv{A<: AbstractMatrix ,B,F}
2525 weight:: A
3737
3838@functor GCNConv
3939
40- (l:: GCNConv )(Ã:: AbstractArray , x:: AbstractArray ) = l. σ .(l. weight * x * Ã .+ l. bias)
40+ (l:: GCNConv )(Ã:: AbstractMatrix , x:: AbstractMatrix ) = l. σ .(l. weight * x * Ã .+ l. bias)
4141
4242function (l:: GCNConv )(fg:: AbstractFeaturedGraph )
4343 nf = node_feature (fg)
4444 Ã = Zygote. ignore () do
4545 GraphSignals. normalized_adjacency_matrix (fg, eltype (nf); selfloop= true )
4646 end
47- return FeaturedGraph (fg, nf = l (Ã, nf))
47+ return ConcreteFeaturedGraph (fg, nf = l (Ã, nf))
4848end
4949
5050function (wg:: WithGraph{<:GCNConv} )(X:: AbstractArray )
51- N = size (X, 2 )
52- wg. subgraph != (:) && N != length (wg. subgraph) &&
53- throw (ArgumentError (" Layer with subgraph expecting subset of features, got #V=$N but #V for subgraph $(length (wg. subgraph)) ." ))
5451 Ã = Zygote. ignore () do
5552 GraphSignals. normalized_adjacency_matrix (wg. fg, eltype (X); selfloop= true )
5653 end
57- return wg. layer (Ã[wg . subgraph, wg . subgraph] , X)
54+ return wg. layer (Ã, X)
5855end
5956
6057function Base. show (io:: IO , l:: GCNConv )
@@ -101,7 +98,7 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
10198
10299Flux. trainable (l:: ChebConv ) = (l. weight, l. bias)
103100
104- function (c:: ChebConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix{T} ) where T
101+ function (c:: ChebConv )(fg:: AbstractFeaturedGraph , X:: AbstractMatrix{T} ) where T
105102 GraphSignals. check_num_nodes (fg, X)
106103 @assert size (X, 1 ) == size (c. weight, 2 ) " Input feature size must match input channel size."
107104
@@ -175,7 +172,7 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
175172
176173update (gc:: GraphConv , m:: AbstractVector , x:: AbstractVector ) = gc. σ .(gc. weight1* x .+ m .+ gc. bias)
177174
178- function (gc:: GraphConv )(fg:: ConcreteFeaturedGraph , x:: AbstractMatrix )
175+ function (gc:: GraphConv )(fg:: AbstractFeaturedGraph , x:: AbstractMatrix )
179176 # GraphSignals.check_num_nodes(fg, x)
180177 _, x, _ = propagate (gc, fg, edge_feature (fg), x, global_feature (fg), + )
181178 x
@@ -290,7 +287,7 @@ function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractM
290287 return M
291288end
292289
293- function (gat:: GATConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
290+ function (gat:: GATConv )(fg:: AbstractFeaturedGraph , X:: AbstractMatrix )
294291 GraphSignals. check_num_nodes (fg, X)
295292 _, X, _ = propagate (gat, fg, edge_feature (fg), X, global_feature (fg), + )
296293 return X
@@ -349,7 +346,7 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
349346update (ggc:: GatedGraphConv , m:: AbstractVector , x) = m
350347
351348
352- function (ggc:: GatedGraphConv )(fg:: ConcreteFeaturedGraph , H:: AbstractMatrix{S} ) where {T<: AbstractVector ,S<: Real }
349+ function (ggc:: GatedGraphConv )(fg:: AbstractFeaturedGraph , H:: AbstractMatrix{S} ) where {T<: AbstractVector ,S<: Real }
353350 GraphSignals. check_num_nodes (fg, H)
354351 m, n = size (H)
355352 @assert (m <= ggc. out_ch) " number of input features must less or equals to output features."
@@ -406,7 +403,7 @@ Flux.trainable(l::EdgeConv) = (l.nn,)
406403message (ec:: EdgeConv , x_i:: AbstractVector , x_j:: AbstractVector , e_ij) = ec. nn (vcat (x_i, x_j .- x_i))
407404update (ec:: EdgeConv , m:: AbstractVector , x) = m
408405
409- function (ec:: EdgeConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
406+ function (ec:: EdgeConv )(fg:: AbstractFeaturedGraph , X:: AbstractMatrix )
410407 GraphSignals. check_num_nodes (fg, X)
411408 _, X, _ = propagate (ec, fg, edge_feature (fg), X, global_feature (fg), ec. aggr)
412409 X
@@ -457,7 +454,7 @@ Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
457454message (g:: GINConv , x_i:: AbstractVector , x_j:: AbstractVector ) = x_j
458455update (g:: GINConv , m:: AbstractVector , x) = g. nn ((1 + g. eps) * x + m)
459456
460- function (g:: GINConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix )
457+ function (g:: GINConv )(fg:: AbstractFeaturedGraph , X:: AbstractMatrix )
461458 gf = graph (fg)
462459 GraphSignals. check_num_nodes (gf, X)
463460 _, X, _ = propagate (g, fg, edge_feature (fg), X, global_feature (fg), + )
@@ -526,7 +523,7 @@ message(c::CGConv,
526523end
527524update (c:: CGConv , m:: AbstractVector , x) = x + m
528525
529- function (c:: CGConv )(fg:: ConcreteFeaturedGraph , X:: AbstractMatrix , E:: AbstractMatrix )
526+ function (c:: CGConv )(fg:: AbstractFeaturedGraph , X:: AbstractMatrix , E:: AbstractMatrix )
530527 GraphSignals. check_num_nodes (fg, X)
531528 GraphSignals. check_num_edges (fg, E)
532529 _, Y, _ = propagate (c, fg, E, X, global_feature (fg), + )
0 commit comments