
Commit 0e95727

implement new GCNConv

1 parent 995b4d8 commit 0e95727

10 files changed: +214 -122 lines

Project.toml

Lines changed: 4 additions & 4 deletions

@@ -27,13 +27,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 CUDA = "3"
 ChainRulesCore = "1.7"
 DataStructures = "0.18"
-FillArrays = "0.12"
+FillArrays = "0.12 - 0.13"
 Flux = "0.12"
 GraphMLDatasets = "0.1"
 GraphSignals = "0.3"
-Graphs = "1.4"
-NNlib = "0.7"
-NNlibCUDA = "0.1"
+Graphs = "1"
+NNlib = "0.7 - 0.8"
+NNlibCUDA = "0.1 - 0.2"
 Reexport = "1.1"
 Word2Vec = "0.5"
 Zygote = "0.6"

src/GeometricFlux.jl

Lines changed: 4 additions & 3 deletions

@@ -54,8 +54,9 @@ export
   InnerProductDecoder,
   VariationalEncoder,
 
-  # layer/misc
-  Bypass,
+  # layer/utils
+  WithGraph,
+  GraphParallel,
 
   #node2vec
   node2vec
@@ -68,10 +69,10 @@ include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
 
+include("layers/utils.jl")
 include("layers/conv.jl")
 include("layers/pool.jl")
 include("models.jl")
-include("layers/misc.jl")
 
 include("sampling.jl")
 include("embedding/node2vec.jl")

src/layers/conv.jl

Lines changed: 27 additions & 17 deletions

@@ -1,51 +1,61 @@
 """
-    GCNConv([fg,] in => out, σ=identity; bias=true, init=glorot_uniform)
+    GCNConv(in => out, σ=identity; bias=true, init=glorot_uniform)
 
-Graph convolutional layer.
+Graph convolutional layer. The input to the layer is a node feature array `X`
+of size `(num_features, num_nodes)`.
 
 # Arguments
 
-- `fg`: Optionally pass a [`FeaturedGraph`](@ref).
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `σ`: Activation function.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 
+# Example
 
-The input to the layer is a node feature array `X`
-of size `(num_features, num_nodes)`.
+```jldoctest
+julia> gc = GCNConv(1024=>256, relu)
+GCNConv(1024 => 256, relu)
+```
+
+See also [`WithGraph`](@ref) for training a layer with a fixed graph or subgraph.
 """
-struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
+struct GCNConv{A<:AbstractMatrix,B,F}
     weight::A
     bias::B
     σ::F
-    fg::S
 end
 
-function GCNConv(fg::AbstractFeaturedGraph, ch::Pair{Int,Int}, σ=identity;
+function GCNConv(ch::Pair{Int,Int}, σ=identity;
                  init=glorot_uniform, bias::Bool=true)
     in, out = ch
     W = init(out, in)
    b = Flux.create_bias(W, bias, out)
-    GCNConv(W, b, σ, fg)
+    GCNConv(W, b, σ)
 end
 
-GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
-    GCNConv(NullGraph(), ch, σ; kwargs...)
-
 @functor GCNConv
 
-Flux.trainable(l::GCNConv) = (l.weight, l.bias)
+(l::GCNConv)(Ã::AbstractArray, x::AbstractArray) = l.σ.(l.weight * x * Ã .+ l.bias)
 
-function (l::GCNConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
+function (l::GCNConv)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
     Ã = Zygote.ignore() do
-        GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
+        GraphSignals.normalized_adjacency_matrix(fg, eltype(nf); selfloop=true)
     end
-    l.σ.(l.weight * x * Ã .+ l.bias)
+    return FeaturedGraph(fg, nf = l(Ã, nf))
 end
 
-(l::GCNConv)(fg::AbstractFeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
+function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
+    N = size(X, 2)
+    wg.subgraph != (:) && N != length(wg.subgraph) &&
+        throw(ArgumentError("Layer with subgraph expects a subset of features, got #V=$N but #V for subgraph is $(length(wg.subgraph))."))
+    Ã = Zygote.ignore() do
+        GraphSignals.normalized_adjacency_matrix(wg.fg, eltype(X); selfloop=true)
+    end
+    return wg.layer(Ã[wg.subgraph, wg.subgraph], X)
+end
 
 function Base.show(io::IO, l::GCNConv)
     out, in = size(l.weight)
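
The reworked layer decouples the graph from the transform: `GCNConv` now holds only `weight`, `bias`, and `σ`, and computes `σ.(W * X * Ã .+ b)` against a normalized adjacency `Ã` built on the fly from whichever `FeaturedGraph` it receives, or fixed once via `WithGraph`. A minimal sketch of the two call patterns, based on the docstring and tests in this commit (the feature sizes are illustrative):

```julia
using GeometricFlux, GraphSignals, Flux

adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]

gc = GCNConv(1024=>256, relu)
X = rand(Float32, 1024, 4)         # (num_features, num_nodes)

# Variable graph: pass a FeaturedGraph in, get a FeaturedGraph out.
fg = gc(FeaturedGraph(adj, nf=X))
size(node_feature(fg))             # (256, 4)

# Fixed graph: bind the layer to one graph, then feed raw feature arrays.
wg = WithGraph(gc, FeaturedGraph(adj))
size(wg(X))                        # (256, 4)
```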

src/layers/misc.jl

Lines changed: 0 additions & 32 deletions
This file was deleted.

src/layers/utils.jl

Lines changed: 96 additions & 0 deletions

@@ -0,0 +1,96 @@
+"""
+    WithGraph(layer, fg, [subgraph=:])
+
+Train GNN layers with a fixed graph.
+
+# Arguments
+
+- `layer`: A GNN layer.
+- `fg`: A fixed `FeaturedGraph` to train with.
+- `subgraph`: Node indices to take a subgraph from `fg`.
+
+# Example
+
+```jldoctest
+julia> adj = [0 1 0 1;
+              1 0 1 0;
+              0 1 0 1;
+              1 0 1 0];
+
+julia> fg = FeaturedGraph(adj);
+
+julia> gc = WithGraph(GCNConv(1024=>256), fg)
+WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4))
+
+julia> subgraph = [1, 2, 4];  # specify subgraph nodes
+
+julia> gc = WithGraph(GCNConv(1024=>256), fg, subgraph)
+WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4), subgraph=[1, 2, 4])
+```
+"""
+struct WithGraph{L,G<:AbstractFeaturedGraph,S}
+    layer::L
+    fg::G
+    subgraph::S
+end
+
+@functor WithGraph
+
+Flux.trainable(l::WithGraph) = (l.layer, )
+
+WithGraph(layer, fg::AbstractFeaturedGraph) = WithGraph(layer, fg, :)
+
+function Base.show(io::IO, l::WithGraph)
+    print(io, "WithGraph(")
+    print(io, l.layer, ", ")
+    print(io, "FeaturedGraph(#V=", nv(l.fg), ", #E=", ne(l.fg), ")")
+    l.subgraph == (:) || print(io, ", subgraph=", l.subgraph)
+    print(io, ")")
+end
+
+"""
+    GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity)
+
+Pass the features of a `FeaturedGraph` through separate layers in parallel. It takes a
+`FeaturedGraph` as input, and a layer can be assigned to each kind of feature
+(node, edge, and global).
+
+# Arguments
+
+- `node_layer`: A regular Flux layer for passing node features.
+- `edge_layer`: A regular Flux layer for passing edge features.
+- `global_layer`: A regular Flux layer for passing global features.
+
+# Example
+
+```jldoctest
+julia> l = GraphParallel(
+           node_layer=Dropout(0.5),
+           global_layer=Dense(10, 5)
+       )
+```
+"""
+struct GraphParallel{N,E,G}
+    node_layer::N
+    edge_layer::E
+    global_layer::G
+end
+
+@functor GraphParallel
+
+GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) =
+    GraphParallel(node_layer, edge_layer, global_layer)
+
+function (l::GraphParallel)(fg::FeaturedGraph)
+    nf = l.node_layer(node_feature(fg))
+    ef = l.edge_layer(edge_feature(fg))
+    gf = l.global_layer(global_feature(fg))
+    return FeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
+end
+
+function Base.show(io::IO, l::GraphParallel)
+    print(io, "GraphParallel(")
+    print(io, "node_layer=", l.node_layer)
+    print(io, ", edge_layer=", l.edge_layer)
+    print(io, ", global_layer=", l.global_layer)
+    print(io, ")")
+end
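
`GraphParallel` routes each kind of feature through its own Flux layer and reassembles the `FeaturedGraph`; anything left at the default falls through `identity`. A small sketch, assuming `FeaturedGraph` accepts the `nf`/`gf` keywords used in the code above (the layer sizes are made up):

```julia
using GeometricFlux, GraphSignals, Flux

adj = [0 1; 1 0]
fg = FeaturedGraph(adj, nf=rand(Float32, 3, 2), gf=rand(Float32, 10))

l = GraphParallel(node_layer=Dense(3, 5), global_layer=Dense(10, 5))
fg′ = l(fg)                # node features go through Dense(3, 5),
                           # global features through Dense(10, 5)
size(node_feature(fg′))    # (5, 2); edge features pass through `identity`
```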

test/cuda/conv.jl

Lines changed: 33 additions & 20 deletions

@@ -2,6 +2,7 @@
 T = Float32
 in_channel = 3
 out_channel = 5
+
 N = 4
 adj = T[0 1 0 1;
         1 0 1 0;
@@ -11,17 +12,29 @@
 fg = FeaturedGraph(adj)
 
 @testset "GCNConv" begin
-    gc = GCNConv(fg, in_channel=>out_channel) |> gpu
-    @test size(gc.weight) == (out_channel, in_channel)
-    @test size(gc.bias) == (out_channel,)
-    @test collect(GraphSignals.adjacency_matrix(gc.fg)) == adj
+    X = rand(T, in_channel, N)
 
-    X = rand(in_channel, N) |> gpu
-    Y = gc(X)
-    @test size(Y) == (out_channel, N)
+    @testset "layer without graph" begin
+        gc = GCNConv(in_channel=>out_channel) |> gpu
+        @test size(gc.weight) == (out_channel, in_channel)
+        @test size(gc.bias) == (out_channel,)
 
-    g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc))
-    @test length(g.grads) == 2
+        fg = FeaturedGraph(adj, nf=X) |> gpu
+        fg_ = gc(fg)
+        @test size(node_feature(fg_)) == (out_channel, N)
+
+        g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
+        @test length(g.grads) == 4
+    end
+
+    @testset "layer with fixed graph" begin
+        gc = WithGraph(GCNConv(in_channel=>out_channel), fg) |> gpu
+        Y = gc(X |> gpu)
+        @test size(Y) == (out_channel, N)
+
+        g = Zygote.gradient(() -> sum(gc(X |> gpu)), Flux.params(gc))
+        @test length(g.grads) == 3
+    end
 end
 
@@ -43,19 +56,19 @@
     @test length(g.grads) == 2
 end
 
-@testset "GraphConv" begin
-    gc = GraphConv(fg, in_channel=>out_channel) |> gpu
-    @test size(gc.weight1) == (out_channel, in_channel)
-    @test size(gc.weight2) == (out_channel, in_channel)
-    @test size(gc.bias) == (out_channel,)
+# @testset "GraphConv" begin
+#     gc = GraphConv(fg, in_channel=>out_channel) |> gpu
+#     @test size(gc.weight1) == (out_channel, in_channel)
+#     @test size(gc.weight2) == (out_channel, in_channel)
+#     @test size(gc.bias) == (out_channel,)
 
-    X = rand(in_channel, N) |> gpu
-    Y = gc(X)
-    @test size(Y) == (out_channel, N)
+#     X = rand(in_channel, N) |> gpu
+#     Y = gc(X)
+#     @test size(Y) == (out_channel, N)
 
-    g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc))
-    @test length(g.grads) == 3
-end
+#     g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc))
+#     @test length(g.grads) == 3
+# end
 
 @testset "GATConv" begin
     adj = T[1 1 0 1;
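
The GPU tests only assert output shapes and gradient counts; a full training step with the fixed-graph wrapper would look roughly like the sketch below. This is an illustration, not code from the commit; the loss, optimizer, and data are placeholders:

```julia
using GeometricFlux, GraphSignals, Flux, Zygote, CUDA

adj = [0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0]
model = WithGraph(GCNConv(3=>5), FeaturedGraph(adj)) |> gpu
X = rand(Float32, 3, 4) |> gpu

ps = Flux.params(model)    # Flux.trainable(::WithGraph) keeps the graph out of ps
gs = Zygote.gradient(() -> sum(model(X)), ps)
Flux.Optimise.update!(ADAM(1f-3), ps, gs)
```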

test/layers/conv.jl

Lines changed: 22 additions & 18 deletions

@@ -15,33 +15,16 @@
 @testset "GCNConv" begin
     X = rand(T, in_channel, N)
     Xt = transpose(rand(T, N, in_channel))
-    @testset "layer with graph" begin
-        gc = GCNConv(fg, in_channel=>out_channel)
-        @test size(gc.weight) == (out_channel, in_channel)
-        @test size(gc.bias) == (out_channel,)
-        @test GraphSignals.adjacency_matrix(gc.fg) == adj
-
-        Y = gc(X)
-        @test size(Y) == (out_channel, N)
-
-        # Test with transposed features
-        Y = gc(Xt)
-        @test size(Y) == (out_channel, N)
-
-        g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc))
-        @test length(g.grads) == 2
-    end
 
     @testset "layer without graph" begin
         gc = GCNConv(in_channel=>out_channel)
         @test size(gc.weight) == (out_channel, in_channel)
         @test size(gc.bias) == (out_channel,)
-        @test !has_graph(gc.fg)
 
         fg = FeaturedGraph(adj, nf=X)
         fg_ = gc(fg)
         @test size(node_feature(fg_)) == (out_channel, N)
-        @test_throws ArgumentError gc(X)
+        @test_throws MethodError gc(X)
 
         # Test with transposed features
         fgt = FeaturedGraph(adj, nf=Xt)
@@ -52,6 +35,27 @@
         @test length(g.grads) == 4
     end
 
+    @testset "layer with fixed graph" begin
+        gc = WithGraph(GCNConv(in_channel=>out_channel), fg)
+        Y = gc(X)
+        @test size(Y) == (out_channel, N)
+
+        # Test with transposed features
+        Y = gc(Xt)
+        @test size(Y) == (out_channel, N)
+
+        g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc))
+        @test length(g.grads) == 3
+    end
+
+    @testset "layer with subgraph" begin
+        X = rand(T, in_channel, 3)
+        subgraph = [1,2,4]
+        gc = WithGraph(GCNConv(in_channel=>out_channel), fg, subgraph)
+        Y = gc(X)
+        @test size(Y) == (out_channel, 3)
+    end
+
     @testset "bias=false" begin
         @test length(Flux.params(GCNConv(2=>3))) == 2
         @test length(Flux.params(GCNConv(2=>3, bias=false))) == 1
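
The subgraph test pins down the wrapper's contract for node subsets: the caller supplies features only for the subgraph's nodes (three columns for nodes `[1, 2, 4]` of a 4-node graph), and `WithGraph` slices the normalized adjacency to match before applying the layer. A minimal sketch mirroring that test:

```julia
using GeometricFlux, GraphSignals, Flux

adj = [0 1 0 1; 1 0 1 0; 0 1 0 1; 1 0 1 0]
gc = WithGraph(GCNConv(3=>5), FeaturedGraph(adj), [1, 2, 4])
Y = gc(rand(Float32, 3, 3))    # one feature column per subgraph node
size(Y)                        # (5, 3)
```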
