
Commit bd24959

semi-supervised GCN baseline
fix
1 parent 0e95727 commit bd24959

6 files changed, +142 -78 lines changed

examples/gcn.jl

Lines changed: 119 additions & 40 deletions
@@ -1,48 +1,127 @@
+using CUDA
+using Flux
+using Flux: onehotbatch, onecold
+using Flux.Losses: logitcrossentropy
+using Flux.Data: DataLoader
 using GeometricFlux
+using GeometricFlux.Datasets
 using GraphSignals
-using Flux
-using Flux: onehotbatch, onecold, logitcrossentropy, throttle
-using Flux: @epochs
-using JLD2
+using Logging: with_logger
+using Parameters: @with_kw
+using ProgressMeter: Progress, next!
 using Statistics
-using SparseArrays
-using Graphs.SimpleGraphs
-using CUDA
+using Random
+
+CUDA.allowscalar(false)
+
+function load_data(dataset, batch_size)
+    # (train_X, train_y) dim: (num_features, target_dim) × 140
+    train_X, train_y = map(x->Matrix(x), traindata(Planetoid(), dataset))
+    # (test_X, test_y) dim: (num_features, target_dim) × 1000
+    test_X, test_y = map(x->Matrix(x), testdata(Planetoid(), dataset))
+    g = graphdata(Planetoid(), dataset)
+    train_idx = train_indices(Planetoid(), dataset)
+    test_idx = test_indices(Planetoid(), dataset)
+
+    train_data = [(subgraph(FeaturedGraph(g, nf=train_X), train_idx), train_y) for _ in 1:100];
+    test_data = [(subgraph(FeaturedGraph(g, nf=test_X), test_idx), test_y) for _ in 1:100];
+    train_batch = Flux.batch(train_data)
+    test_batch = Flux.batch(test_data)
+
+    train_loader = DataLoader(train_batch, batchsize=batch_size, shuffle=true)
+    test_loader = DataLoader(test_batch, batchsize=batch_size, shuffle=true)
+    return train_loader, test_loader
+end
+
+@with_kw mutable struct Args
+    η = 0.01            # learning rate
+    λ = 5f-4            # regularization parameter
+    batch_size = 32     # batch size
+    num_nodes = 2708    # number of nodes for graph
+    epochs = 200        # number of epochs
+    seed = 0            # random seed
+    cuda = true         # use GPU
+    input_dim = 1433    # input dimension
+    hidden_dim = 16     # hidden dimension
+    target_dim = 7      # target dimension
+end
 
-@load "data/cora_features.jld2" features
-@load "data/cora_labels.jld2" labels
-@load "data/cora_graph.jld2" g
-
-num_nodes = 2708
-num_features = 1433
-hidden = 16
-target_catg = 7
-epochs = 200
-λ = 5e-4
-
-## Preprocessing data
-train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
-train_y = Matrix{Float32}(labels) |> gpu  # dim: target_catg * num_nodes
-fg = FeaturedGraph(g)  # pass to gpu together in model layers
-
-## Model
-model = Chain(GCNConv(fg, num_features=>hidden, relu),
-              Dropout(0.5),
-              GCNConv(fg, hidden=>target_catg),
-              ) |> gpu;
-# do not show model architecture, showing CuSparseMatrix will trigger errors
-
-## Loss
+## Loss: cross entropy with first layer L2 regularization
 l2norm(x) = sum(abs2, x)
-# cross entropy with first layer L2 regularization
-loss(x, y) = logitcrossentropy(model(x), y) + λ*sum(l2norm, Flux.params(model[1]))
-accuracy(x, y) = mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y)))
+function model_loss(model, λ, batch)
+    loss = 0.0f0
+    for (x, y) in batch
+        loss += logitcrossentropy(model(x), y)
+        loss += λ*sum(l2norm, Flux.params(model[1]))
+    end
+    return loss
+end
+
+function accuracy(model, batch::AbstractVector)
+    return mean(mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y))) for (x, y) in batch)
+end
+
+accuracy(model, loader::DataLoader, device) = mean(accuracy(model, batch |> device) for batch in loader)
+
+function train(; kws...)
+    # load hyperparameters
+    args = Args(; kws...)
+    args.seed > 0 && Random.seed!(args.seed)
+
+    # GPU config
+    if args.cuda && CUDA.has_cuda()
+        device = gpu
+        @info "Training on GPU"
+    else
+        device = cpu
+        @info "Training on CPU"
+    end
+
+    # load Cora from Planetoid dataset
+    train_loader, test_loader = load_data(:cora, args.batch_size)
+
+    # build model
+    model = Chain(
+        GCNConv(args.input_dim=>args.hidden_dim, relu),
+        GraphParallel(node_layer=Dropout(0.5)),
+        GCNConv(args.hidden_dim=>args.target_dim),
+        node_feature,
+    ) |> device
+
+    # ADAM optimizer
+    opt = ADAM(args.η)
+
+    # parameters
+    ps = Flux.params(model)
+
+    # training
+    train_steps = 0
+    @info "Start Training, total $(args.epochs) epochs"
+    for epoch = 1:args.epochs
+        @info "Epoch $(epoch)"
+        progress = Progress(length(train_loader))
+
+        for batch in train_loader
+            loss, back = Flux.pullback(ps) do
+                model_loss(model, args.λ, batch |> device)
+            end
+            train_acc = accuracy(model, train_loader, device)
+            test_acc = accuracy(model, test_loader, device)
+            grad = back(1f0)
+            Flux.Optimise.update!(opt, ps, grad)
+
+            # progress meter
+            next!(progress; showvalues=[
+                (:loss, loss),
+                (:train_accuracy, train_acc),
+                (:test_accuracy, test_acc)
+            ])
 
+            train_steps += 1
+        end
+    end
 
-## Training
-ps = Flux.params(model)
-train_data = [(train_X, train_y)]
-opt = ADAM(0.01)
-evalcb() = @show(accuracy(train_X, train_y))
+    return model, args
+end
 
-@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
+model, args = train()
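
Since `train` forwards its keyword arguments straight into the `@with_kw` struct, any `Args` field can be overridden per run. A minimal usage sketch (the keyword names are just the `Args` fields above; `cuda=false` is an assumption for machines without a GPU):

    # override hyperparameters per run; every Args field is a keyword
    model, args = train(batch_size=16, epochs=50, seed=42, cuda=false)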

src/GeometricFlux.jl

Lines changed: 0 additions & 2 deletions
@@ -20,8 +20,6 @@ using Zygote
 
 import Word2Vec: word2vec, wordvectors, get_vector
 
-const ConcreteFeaturedGraph = Union{FeaturedGraph,FeaturedSubgraph}
-
 export
     # layers/graphlayers
     AbstractGraphLayer,
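
With the `Union` alias removed here, the layers in `src/layers/conv.jl` below dispatch on `AbstractFeaturedGraph` and call `ConcreteFeaturedGraph(fg, nf=...)` as a constructor, presumably now supplied by GraphSignals. A sketch of the widened dispatch, using a hypothetical `MyLayer`:

    using GraphSignals

    struct MyLayer end  # hypothetical layer, only to illustrate dispatch

    # one method now covers FeaturedGraph and FeaturedSubgraph alike
    (::MyLayer)(fg::AbstractFeaturedGraph) = node_feature(fg)

    fg = FeaturedGraph([0 1; 1 0], nf=rand(Float32, 3, 2))
    MyLayer()(fg)                    # FeaturedGraph
    MyLayer()(subgraph(fg, [1, 2]))  # FeaturedSubgraph, same method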

src/layers/conv.jl

Lines changed: 11 additions & 14 deletions
@@ -19,7 +19,7 @@ julia> gc = GCNConv(1024=>256, relu)
 GCNConv(1024 => 256, relu)
 ```
 
-See also [`WithGraph`](@ref) for training layer with fixed graph or subgraph.
+See also [`WithGraph`](@ref) for training layer with fixed graph.
 """
 struct GCNConv{A<:AbstractMatrix,B,F}
     weight::A
@@ -37,24 +37,21 @@ end
 
 @functor GCNConv
 
-(l::GCNConv)(Ã::AbstractArray, x::AbstractArray) = l.σ.(l.weight * x * Ã .+ l.bias)
+(l::GCNConv)(Ã::AbstractMatrix, x::AbstractMatrix) = l.σ.(l.weight * x * Ã .+ l.bias)
 
 function (l::GCNConv)(fg::AbstractFeaturedGraph)
     nf = node_feature(fg)
     Ã = Zygote.ignore() do
         GraphSignals.normalized_adjacency_matrix(fg, eltype(nf); selfloop=true)
     end
-    return FeaturedGraph(fg, nf = l(Ã, nf))
+    return ConcreteFeaturedGraph(fg, nf = l(Ã, nf))
 end
 
 function (wg::WithGraph{<:GCNConv})(X::AbstractArray)
-    N = size(X, 2)
-    wg.subgraph != (:) && N != length(wg.subgraph) &&
-        throw(ArgumentError("Layer with subgraph expecting subset of features, got #V=$N but #V for subgraph $(length(wg.subgraph))."))
     Ã = Zygote.ignore() do
         GraphSignals.normalized_adjacency_matrix(wg.fg, eltype(X); selfloop=true)
     end
-    return wg.layer(Ã[wg.subgraph, wg.subgraph], X)
+    return wg.layer(Ã, X)
 end
 
 function Base.show(io::IO, l::GCNConv)
@@ -101,7 +98,7 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
 
 Flux.trainable(l::ChebConv) = (l.weight, l.bias)
 
-function (c::ChebConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix{T}) where T
+function (c::ChebConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix{T}) where T
     GraphSignals.check_num_nodes(fg, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
 
@@ -175,7 +172,7 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
 
 update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
 
-function (gc::GraphConv)(fg::ConcreteFeaturedGraph, x::AbstractMatrix)
+function (gc::GraphConv)(fg::AbstractFeaturedGraph, x::AbstractMatrix)
     # GraphSignals.check_num_nodes(fg, x)
     _, x, _ = propagate(gc, fg, edge_feature(fg), x, global_feature(fg), +)
     x
@@ -290,7 +287,7 @@ function update_batch_vertex(gat::GATConv, ::AbstractFeaturedGraph, M::AbstractM
     return M
 end
 
-function (gat::GATConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+function (gat::GATConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix)
     GraphSignals.check_num_nodes(fg, X)
     _, X, _ = propagate(gat, fg, edge_feature(fg), X, global_feature(fg), +)
     return X
@@ -349,7 +346,7 @@ message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
 
 
-function (ggc::GatedGraphConv)(fg::ConcreteFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (ggc::GatedGraphConv)(fg::AbstractFeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
     GraphSignals.check_num_nodes(fg, H)
     m, n = size(H)
     @assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
@@ -406,7 +403,7 @@ Flux.trainable(l::EdgeConv) = (l.nn,)
 message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
 update(ec::EdgeConv, m::AbstractVector, x) = m
 
-function (ec::EdgeConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+function (ec::EdgeConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix)
    GraphSignals.check_num_nodes(fg, X)
    _, X, _ = propagate(ec, fg, edge_feature(fg), X, global_feature(fg), ec.aggr)
    X
@@ -457,7 +454,7 @@ Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
 
-function (g::GINConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix)
+function (g::GINConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix)
     gf = graph(fg)
     GraphSignals.check_num_nodes(gf, X)
     _, X, _ = propagate(g, fg, edge_feature(fg), X, global_feature(fg), +)
@@ -526,7 +523,7 @@ message(c::CGConv,
 end
 update(c::CGConv, m::AbstractVector, x) = x + m
 
-function (c::CGConv)(fg::ConcreteFeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
+function (c::CGConv)(fg::AbstractFeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
     GraphSignals.check_num_nodes(fg, X)
     GraphSignals.check_num_edges(fg, E)
     _, Y, _ = propagate(c, fg, E, X, global_feature(fg), +)
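
After this change `WithGraph` only caches the graph; the normalized adjacency Ã is rebuilt inside `Zygote.ignore` on every forward pass, with no subgraph slicing. A minimal sketch of the fixed-graph path, mirroring the `WithGraph` docstring (the 4-node adjacency and feature sizes are illustrative):

    using GeometricFlux, GraphSignals

    adj = [0 1 0 1;
           1 0 1 0;
           0 1 0 1;
           1 0 1 0]
    fg = FeaturedGraph(adj)

    gc = WithGraph(GCNConv(8=>4, relu), fg)  # graph fixed once
    X = rand(Float32, 8, 4)  # 8 input features per node, 4 nodes
    Y = gc(X)                # 4×4 output: l.σ.(l.weight * X * Ã .+ l.bias)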

src/layers/gn.jl

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ end
 
 function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
     E, V, u = propagate(gn, fg, edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr)
-    FeaturedGraph(fg, nf=V, ef=E, gf=u)
+    return FeaturedGraph(fg, nf=V, ef=E, gf=u)
 end
 
 """

src/layers/utils.jl

Lines changed: 4 additions & 14 deletions
@@ -1,13 +1,12 @@
 """
-    WithGraph(layer, fg, [subgraph=:])
+    WithGraph(layer, fg)
 
 Train GNN layers with fixed graph.
 
 # Arguments
 
 - `layer`: A GNN layer.
 - `fg`: A fixed `FeaturedGraph` to train with.
-- `subgraph`: Node indeices to get a subgraph from `fg`.
 
 # Example
@@ -21,30 +20,21 @@ julia> fg = FeaturedGraph(adj);
 
 julia> gc = WithGraph(GCNConv(1024=>256), fg)
 WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4))
-
-julia> subgraph = [1, 2, 4] # specify subgraph nodes
-
-julia> gc = WithGraph(GCNConv(1024=>256), fg, subgraph)
-WithGraph(GCNConv(1024 => 256), FeaturedGraph(#V=4, #E=4), subgraph=[1, 2, 4])
 ```
 """
-struct WithGraph{L,G<:AbstractFeaturedGraph,S}
+struct WithGraph{L,G<:AbstractFeaturedGraph}
     layer::L
     fg::G
-    subgraph::S
 end
 
 @functor WithGraph
 
 Flux.trainable(l::WithGraph) = (l.layer, )
 
-WithGraph(layer, fg::AbstractFeaturedGraph) = WithGraph(layer, fg, :)
-
 function Base.show(io::IO, l::WithGraph)
     print(io, "WithGraph(")
     print(io, l.layer, ", ")
     print(io, "FeaturedGraph(#V=", nv(l.fg), ", #E=", ne(l.fg), ")")
-    l.subgraph == (:) || print(io, ", subgraph=", l.subgraph)
     print(io, ")")
 end
 
@@ -80,11 +70,11 @@ end
 GraphParallel(; node_layer=identity, edge_layer=identity, global_layer=identity) =
     GraphParallel(node_layer, edge_layer, global_layer)
 
-function (l::GraphParallel)(fg::FeaturedGraph)
+function (l::GraphParallel)(fg::AbstractFeaturedGraph)
     nf = l.node_layer(node_feature(fg))
     ef = l.edge_layer(edge_feature(fg))
     gf = l.global_layer(global_feature(fg))
-    return FeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
+    return ConcreteFeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
 end
 
 function Base.show(io::IO, l::GraphParallel)
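
`GraphParallel` routes each feature channel through its own layer, which is how the example script applies `Dropout` to node features only. A short sketch (graph and sizes are illustrative):

    using GeometricFlux, GraphSignals, Flux

    adj = [0 1 1;
           1 0 1;
           1 1 0]
    fg = FeaturedGraph(adj, nf=rand(Float32, 16, 3))

    # edge_layer and global_layer default to identity,
    # so only node features are transformed
    drop = GraphParallel(node_layer=Dropout(0.5))
    fg2 = drop(fg)  # rebuilt featured graph with dropped-out node features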
