|
| 1 | +""" |
| 2 | + WithGraph(layer, fg, [subgraph=:]) |
| 3 | +
|
| 4 | +Train GNN layers with fixed graph. |
| 5 | +
|
| 6 | +# Arguments |
| 7 | +
|
| 8 | +- `layer`: A GNN layer. |
| 9 | +- `fg`: A fixed `FeaturedGraph` to train with. |
| 10 | +- `subgraph`: Node indeices to get a subgraph from `fg`. |
| 11 | +
|
| 12 | +# Example |
| 13 | +
|
| 14 | +```jldoctest |
| 15 | +julia> adj = [0 1 0 1; |
| 16 | + 1 0 1 0; |
| 17 | + 0 1 0 1; |
| 18 | + 1 0 1 0]; |
| 19 | +
|
| 20 | +julia> fg = FeaturedGraph(adj); |
| 21 | +
|
| 22 | +julia> gc = WithGraph(GCNConv(1024=>256), fg) |
| 23 | +WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4)) |
| 24 | +
|
| 25 | +julia> subgraph = [1, 2, 4] # specify subgraph nodes |
| 26 | +
|
| 27 | +julia> gc = WithGraph(GCNConv(1024=>256), fg, subgraph) |
| 28 | +WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4), subgraph=[1, 2, 4]) |
| 29 | +``` |
| 30 | +""" |
| 31 | +struct WithGraph{L,G<:AbstractFeaturedGraph,S} |
| 32 | + layer::L |
| 33 | + fg::G |
| 34 | + subgraph::S |
| 35 | +end |
| 36 | + |
| 37 | +@functor WithGraph |
| 38 | + |
| 39 | +Flux.trainable(l::WithGraph) = (l.layer, ) |
| 40 | + |
| 41 | +WithGraph(layer, fg::AbstractFeaturedGraph) = WithGraph(layer, fg, :) |
| 42 | + |
| 43 | +function Base.show(io::IO, l::WithGraph) |
| 44 | + print(io, "WithGraph(") |
| 45 | + print(io, l.layer, ", ") |
| 46 | + print(io, "FeaturedGraph(#V=", nv(l.fg), ", #E=", ne(l.fg), ")") |
| 47 | + l.subgraph == (:) || print(io, ", subgraph=", l.subgraph) |
| 48 | + print(io, ")") |
| 49 | +end |
| 50 | + |
| 51 | +""" |
| 52 | + GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) |
| 53 | +
|
| 54 | +Passing features in `FeaturedGraph` in parallel. It takes `FeaturedGraph` as input |
| 55 | +and it can be specified by assigning layers for specific (node, edge and global) features. |
| 56 | +
|
| 57 | +# Arguments |
| 58 | +
|
| 59 | +- `node_layer`: A regular Flux layer for passing node features. |
| 60 | +- `edge_layer`: A regular Flux layer for passing edge features. |
| 61 | +- `global_layer`: A regular Flux layer for passing global features. |
| 62 | +
|
| 63 | +# Example |
| 64 | +
|
| 65 | +```jldoctest |
| 66 | +julia> l = GraphParallel( |
| 67 | + node_layer=Dropout(0.5), |
| 68 | + global_layer=Dense(10, 5) |
| 69 | + ) |
| 70 | +``` |
| 71 | +""" |
| 72 | +struct GraphParallel{N,E,G} |
| 73 | + node_layer::N |
| 74 | + edge_layer::E |
| 75 | + global_layer::G |
| 76 | +end |
| 77 | + |
| 78 | +@functor GraphParallel |
| 79 | + |
| 80 | +GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) = |
| 81 | + GraphParallel(node_layer, edge_layer, global_layer) |
| 82 | + |
| 83 | +function (l::GraphParallel)(fg::FeaturedGraph) |
| 84 | + nf = l.node_layer(node_feature(fg)) |
| 85 | + ef = l.edge_layer(edge_feature(fg)) |
| 86 | + gf = l.global_layer(global_feature(fg)) |
| 87 | + return FeaturedGraph(fg, nf=nf, ef=ef, gf=gf) |
| 88 | +end |
| 89 | + |
| 90 | +function Base.show(io::IO, l::GraphParallel) |
| 91 | + print(io, "GraphParallel(") |
| 92 | + print(io, "node_layer=", l.node_layer) |
| 93 | + print(io, ", edge_layer=", l.edge_layer) |
| 94 | + print(io, ", global_layer=", l.global_layer) |
| 95 | + print(io, ")") |
| 96 | +end |
0 commit comments