
Commit 2f69b04

semisupervised gcn example
1 parent bd24959 commit 2f69b04

5 files changed: +282 -6 lines changed

docs/make.jl

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@ makedocs(
          "Building layers" => "basics/layers.md",
          "Graph passing" => "basics/passgraph.md"],
      "Cooperate with Flux layers" => "cooperate.md",
+     "Tutorials" =>
+         [
+          "Semi-supervised learning with GCN" => "tutorials/semisupervised_gcn.md",
+         ],
      "Abstractions" =>
          ["Message passing scheme" => "abstractions/msgpass.md",
           "Graph network block" => "abstractions/gn.md"],
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# Semi-supervised Learning with Graph Convolutional Networks (GCN)

Graph convolutional networks (GCN) are often regarded as the first step into graph neural networks (GNN). This example walks through training a vanilla GCN.
## Semi-supervised Learning in Graph Neural Networks

In the semi-supervised setting, node features are available for the whole graph, but labels are given for only a subset of the nodes. We train the model on that labeled subset and test it on a disjoint subset of nodes in the same graph.
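As a toy illustration (the indices below are made up for exposition; the actual Planetoid splits are loaded in Step 1), features exist for every node but only the labeled subset contributes to the loss:

```julia
num_nodes = 10           # features are available for all 10 nodes
labeled_idx = [1, 2, 3]  # labels visible during training
heldout_idx = [8, 9, 10] # held-out nodes, used only for evaluation
```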
## Node Classification Task

Here we tackle node classification: learning a model that predicts a label for every node in a graph. The GCN takes node features as input and outputs a label for each node.
## Step 1: Load Dataset

GeometricFlux provides the Planetoid dataset through `GeometricFlux.Datasets`, which re-exports it from GraphMLDatasets. Planetoid contains three sub-datasets: Cora, Citeseer and PubMed; this example uses Cora. `traindata` loads training data from various datasets: the dataset is specified by the first argument and the sub-dataset by the second.
```julia
using GeometricFlux.Datasets

train_X, train_y = traindata(Planetoid(), :cora)
```
`traindata` returns the pre-defined training features and labels; these features are node features. They can be converted to dense matrices for training:

```julia
train_X, train_y = map(x->Matrix(x), traindata(Planetoid(), :cora))
```
The graph itself is loaded with `graphdata`; it comes preprocessed as a `SimpleGraph`, the graph type provided by Graphs.jl.

```julia
g = graphdata(Planetoid(), :cora)
train_idx = train_indices(Planetoid(), :cora)
```
We also need node indices to take a subgraph of the original graph; `train_indices` returns the node indices used for training.
## Step 2: Wrapping Graph and Features into `FeaturedGraph`

`FeaturedGraph`, provided by GraphSignals, is a container that holds a graph together with its node, edge and global features. To wrap a graph and node features into a `FeaturedGraph`, pass the graph `g` as the first argument and specify the node features with the keyword `nf`.
```julia
using GraphSignals

FeaturedGraph(g, nf=train_X)
```
To take a subgraph from a `FeaturedGraph` object, call `subgraph` with the node indices `train_idx` as the second argument.
```julia
subgraph(FeaturedGraph(g, nf=train_X), train_idx)
```
## Step 3: Build a GCN Model

The GCN model is composed of two `GCNConv` layers, with `relu` as the activation of the first layer and a `Dropout` layer in between. `GraphParallel` integrates a regular Flux layer into the graph pipeline: here it routes node features through `node_layer=Dropout(0.5)`.
```julia
model = Chain(
    GCNConv(input_dim=>hidden_dim, relu),
    GraphParallel(node_layer=Dropout(0.5)),
    GCNConv(hidden_dim=>target_dim),
    node_feature,
)
```
Since the model takes a `FeaturedGraph` object as input, each graph layer outputs a `FeaturedGraph` object as well. At the end of the chain, `node_feature` extracts the node features from the resulting `FeaturedGraph`.
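As a quick sanity check, the model can be applied directly to the training subgraph from Step 2. This is a minimal sketch, assuming `input_dim`, `hidden_dim` and `target_dim` were set before building `model` (e.g. to the Cora dimensions 1433, 16 and 7) and that `g`, `train_X` and `train_idx` are the objects loaded in Step 1:

```julia
fg_train = subgraph(FeaturedGraph(g, nf=train_X), train_idx)

# The graph layers pass a `FeaturedGraph` down the chain; the trailing
# `node_feature` extracts the class scores, one column per training node.
scores = model(fg_train)
size(scores)  # expected: (target_dim, length(train_idx))
```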
## Step 4: Loss Functions and Accuracy

Since this is a node classification task, the model loss is `logitcrossentropy` together with an L2 regularization term. Following the vanilla GCN, L2 regularization is applied only to the first layer, and its strength is controlled by the hyperparameter `λ`.
```julia
l2norm(x) = sum(abs2, x)

function model_loss(model, λ, batch)
    loss = 0.f0
    for (x, y) in batch
        loss += logitcrossentropy(model(x), y)
        loss += λ*sum(l2norm, Flux.params(model[1]))
    end
    return loss
end
```
Accuracy is defined both for a single batch and for a whole `DataLoader`.
```julia
function accuracy(model, batch::AbstractVector)
    return mean(mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y))) for (x, y) in batch)
end

accuracy(model, loader::DataLoader, device) = mean(accuracy(model, batch |> device) for batch in loader)
```
## Step 5: Training the GCN Model

We train the model with the same workflow as any Flux model.
```julia
train_loader, test_loader = load_data(:cora, args.batch_size)

# optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for batch in train_loader
        grad = gradient(() -> model_loss(model, args.λ, batch |> device), ps)
        Flux.Optimise.update!(opt, ps, grad)
        train_steps += 1
    end
end
```
This completes a basic tutorial on training a GCN model!

For the complete example, please check the script `examples/semisupervised_gcn.jl`.
## Acceleration by Pre-computing the Normalized Adjacency Matrix

Training can be slow in this example. Because the graph and the features are bundled in the same `FeaturedGraph` object, `GCNConv` has to recompute the normalized adjacency matrix during training, which adds substantial overhead. Training can be accelerated by pre-computing the normalized adjacency matrix for all `FeaturedGraph` objects: calling the following function computes it for `fg` in place before training, which reduces training time.
```julia
GraphSignals.normalized_adjacency_matrix!(fg)
```
Since `GCNConv` consumes the normalized adjacency matrix, it is safe to pre-compute it here. If a layer does not use a normalized adjacency matrix, this step will lead to an error.
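A minimal sketch of how this could slot into the data pipeline above (assuming the same `g`, `train_X`, `train_y` and `train_idx` from the earlier steps, and that `normalized_adjacency_matrix!` mutates the graph stored in the `FeaturedGraph`, so it only needs to run once before training):

```julia
fg = FeaturedGraph(g, nf=train_X)
GraphSignals.normalized_adjacency_matrix!(fg)  # pre-compute once, before training

train_data = [(subgraph(fg, train_idx), train_y) for _ in 1:100]
```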

examples/gcn_with_fixed_graph.jl

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
using CUDA
using Flux
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using Flux.Data: DataLoader
using GeometricFlux
using GeometricFlux.Datasets
using GraphSignals
using Logging: with_logger
using Parameters: @with_kw
using ProgressMeter: Progress, next!
using Statistics
using Random

function load_data(dataset, batch_size)
    # (train_X, train_y) dim: (num_features, target_dim) × 1708
    train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset))
    # (test_X, test_y) dim: (num_features, target_dim) × 1000
    test_X, test_y = map(x -> Matrix(x), testdata(Planetoid(), dataset))
    g = graphdata(Planetoid(), dataset)
    train_idx = 1:size(train_X, 2)
    test_idx = test_indices(Planetoid(), dataset)

    # padding zeros
    tr_X = zeros(Float32, size(train_X, 1), size(train_X, 2) + size(test_X, 2))
    te_X = zeros(Float32, size(test_X, 1), size(train_X, 2) + size(test_X, 2))
    tr_y = zeros(Float32, size(train_y, 1), size(train_y, 2) + size(test_y, 2))
    te_y = zeros(Float32, size(test_y, 1), size(train_y, 2) + size(test_y, 2))
    tr_X[:, train_idx] .= train_X
    te_X[:, test_idx] .= test_X
    tr_y[:, train_idx] .= train_y
    te_y[:, test_idx] .= test_y

    fg = FeaturedGraph(g)
    train_data = (repeat(tr_X, outer=(1,1,256)), repeat(tr_y, outer=(1,1,256)))
    test_data = (repeat(te_X, outer=(1,1,32)), repeat(te_y, outer=(1,1,32)))
    train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
    test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=true)
    return train_loader, test_loader, fg, train_idx, test_idx
end

@with_kw mutable struct Args
    η = 0.01            # learning rate
    λ = 5f-4            # regularization paramater
    batch_size = 32     # batch size
    num_nodes = 2708    # number of nodes for graph
    epochs = 200        # number of epochs
    seed = 0            # random seed
    cuda = true         # use GPU
    input_dim = 1433    # input dimension
    hidden_dim = 16     # hidden dimension
    target_dim = 7      # target dimension
end

## Loss: cross entropy with first layer L2 regularization
l2norm(x) = sum(abs2, x)
function model_loss(model, λ, X, y, idx)
    loss = logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
    loss += λ*sum(l2norm, Flux.params(model[1]))
    return loss
end

function accuracy(model, X::AbstractArray, y::AbstractArray, idx)
    return mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))
end

accuracy(model, loader::DataLoader, device, idx) = mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)

function train(; kws...)
    # load hyperparamters
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    # GPU config
    if args.cuda && CUDA.has_cuda()
        device = gpu
        @info "Training on GPU"
    else
        device = cpu
        @info "Training on CPU"
    end

    # load Cora from Planetoid dataset
    train_loader, test_loader, fg, train_idx, test_idx = load_data(:cora, args.batch_size)

    # build model
    model = Chain(
        WithGraph(fg, GCNConv(args.input_dim=>args.hidden_dim, relu)),
        Dropout(0.5),
        WithGraph(fg, GCNConv(args.hidden_dim=>args.target_dim)),
    ) |> device

    # ADAM optimizer
    opt = ADAM(args.η)

    # parameters
    ps = Flux.params(model)

    # training
    train_steps = 0
    @info "Start Training, total $(args.epochs) epochs"
    for epoch = 1:args.epochs
        @info "Epoch $(epoch)"
        progress = Progress(length(train_loader))

        for (X, y) in train_loader
            loss, back = Flux.pullback(ps) do
                model_loss(model, args.λ, X |> device, y |> device, train_idx |> device)
            end
            train_acc = accuracy(model, train_loader, device, train_idx)
            test_acc = accuracy(model, test_loader, device, test_idx)
            grad = back(1f0)
            Flux.Optimise.update!(opt, ps, grad)

            # progress meter
            next!(progress; showvalues=[
                (:loss, loss),
                (:train_accuracy, train_acc),
                (:test_accuracy, test_acc)
            ])

            train_steps += 1
        end
    end

    return model, args
end

model, args = train()

examples/gcn.jl renamed to examples/semisupervised_gcn.jl

Lines changed: 3 additions & 5 deletions
@@ -12,9 +12,7 @@ using ProgressMeter: Progress, next!
 using Statistics
 using Random
 
-CUDA.allowscalar(false)
-
-function load_data(dataset, batch_size)
+function load_data(dataset, batch_size, train_repeats=256, test_repeats=32)
     # (train_X, train_y) dim: (num_features, target_dim) × 140
     train_X, train_y = map(x->Matrix(x), traindata(Planetoid(), dataset))
     # (test_X, test_y) dim: (num_features, target_dim) × 1000
@@ -23,8 +21,8 @@ function load_data(dataset, batch_size)
     train_idx = train_indices(Planetoid(), dataset)
     test_idx = test_indices(Planetoid(), dataset)
 
-    train_data = [(subgraph(FeaturedGraph(g, nf=train_X), train_idx), train_y) for _ in 1:100];
-    test_data = [(subgraph(FeaturedGraph(g, nf=test_X), test_idx), test_y) for _ in 1:100];
+    train_data = [(subgraph(FeaturedGraph(g, nf=train_X), train_idx), train_y) for _ in 1:train_repeats]
+    test_data = [(subgraph(FeaturedGraph(g, nf=test_X), test_idx), test_y) for _ in 1:test_repeats]
     train_batch = Flux.batch(train_data)
     test_batch = Flux.batch(test_data)

src/layers/conv.jl

Lines changed: 11 additions & 1 deletion
@@ -39,6 +39,12 @@ end
 
 (l::GCNConv)(Ã::AbstractMatrix, x::AbstractMatrix) = l.σ.(l.weight * x * Ã .+ l.bias)
 
+function (l::GCNConv)(Ã::AbstractMatrix, X::AbstractArray)
+    z = NNlib.batched_mul(l.weight, NNlib.batched_mul(X, Ã))
+    return l.σ.(z .+ l.bias)
+end
+
+# For variable graph
 function (l::GCNConv)(fg::AbstractFeaturedGraph)
     nf = node_feature(fg)
     Ã = Zygote.ignore() do
@@ -47,9 +53,13 @@ function (l::GCNConv)(fg::AbstractFeaturedGraph)
     return ConcreteFeaturedGraph(fg, nf = l(Ã, nf))
 end
 
+# For fixed graph
+WithGraph(fg::AbstractFeaturedGraph, l::GCNConv) =
+    WithGraph(l, GraphSignals.normalized_adjacency_matrix!(fg, eltype(l.weight); selfloop=true))
+
 function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
     Ã = Zygote.ignore() do
-        GraphSignals.normalized_adjacency_matrix(wg.fg, eltype(X); selfloop=true)
+        GraphSignals.normalized_adjacency_matrix(wg.fg)
     end
     return wg.layer(Ã, X)
 end
