Skip to content

Commit 5a8ea5c

Browse files
committed
fix examples
fix GAE, VGAE, GDEs
1 parent 27c2ce4 commit 5a8ea5c

File tree

5 files changed

+21
-16
lines changed

5 files changed

+21
-16
lines changed

examples/gae.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using GeometricFlux
2+
using GraphSignals
23
using Flux
34
using Flux: throttle
45
using Flux.Losses: logitbinarycrossentropy
@@ -9,6 +10,8 @@ using SparseArrays
910
using Graphs.SimpleGraphs
1011
using CUDA
1112

13+
CUDA.allowscalar(false)
14+
1215
@load "data/cora_features.jld2" features
1316
@load "data/cora_graph.jld2" g
1417

@@ -20,14 +23,15 @@ target_catg = 7
2023
epochs = 200
2124

2225
## Preprocessing data
23-
fg = FeaturedGraph(g) |> gpu
26+
fg = FeaturedGraph(g) # pass to gpu together in model layers
2427
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
25-
train_y = fg # dim: num_nodes * num_nodes
28+
train_y = fg |> GraphSignals.adjacency_matrix |> gpu # dim: num_nodes * num_nodes
2629

2730
## Model
2831
encoder = Chain(GCNConv(fg, num_features=>hidden1, relu),
2932
GCNConv(fg, hidden1=>hidden2))
30-
model = Chain(GAE(encoder, σ)) |> gpu
33+
model = Chain(GAE(encoder, σ)) |> gpu;
34+
# do not show model architecture, showing CuSparseMatrix will trigger errors
3135

3236
## Loss
3337
loss(x, y) = logitbinarycrossentropy(model(x), y)

examples/gat.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ using GeometricFlux
22
using Flux
33
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
44
using Flux: @epochs
5-
using Flux.Data: DataLoader
65
using JLD2
76
using Statistics: mean
87
using SparseArrays
8+
using LinearAlgebra
99
using Graphs.SimpleGraphs
1010
using Graphs: adjacency_matrix
1111
using CUDA
@@ -24,29 +24,30 @@ epochs = 10
2424
## Preprocessing data
2525
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
2626
train_y = Matrix{Float32}(labels) |> gpu # dim: target_catg * num_nodes
27-
adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu
27+
A = Matrix{Int}((adjacency_matrix(g) + I) .≥ 1)
28+
fg = FeaturedGraph(A, :adjm)
2829

2930
## Model
30-
model = Chain(GATConv(g, num_features=>hidden, heads=heads),
31+
model = Chain(GATConv(fg, num_features=>hidden, heads=heads),
3132
Dropout(0.6),
32-
GATConv(g, hidden*heads=>target_catg, heads=heads, concat=false)
33+
GATConv(fg, hidden*heads=>target_catg, heads=heads, concat=false)
3334
) |> gpu
3435
# test model
35-
# @show model(train_X)
36+
@show model(train_X)
3637

3738
## Loss
3839
loss(x, y) = logitcrossentropy(model(x), y)
3940
accuracy(x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
4041

4142
# test loss
42-
# @show loss(train_X, train_y)
43+
@show loss(train_X, train_y)
4344

4445
# test gradient
45-
# @show gradient(X -> loss(X, train_y), train_X)
46+
@show gradient(()->loss(train_X, train_y), Flux.params(model))
4647

4748
## Training
4849
ps = Flux.params(model)
49-
train_data = DataLoader(train_X, train_y, batchsize=num_nodes)
50+
train_data = Flux.Data.DataLoader((train_X, train_y), batchsize=num_nodes)
5051
opt = ADAM(0.01)
5152
evalcb() = @show(accuracy(train_X, train_y))
5253

examples/gde.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GeometricFlux, Flux, JLD2, SparseArrays, DiffEqFlux, DifferentialEquations
1+
using GeometricFlux, GraphSignals, Flux, JLD2, SparseArrays, DiffEqFlux, DifferentialEquations
22
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
33
using Flux: @epochs
44
using Statistics: mean

examples/gde_gpu.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using GeometricFlux, Flux, JLD2, SparseArrays, DiffEqFlux, DifferentialEquations
1+
using GeometricFlux, GraphSignals, Flux, JLD2, SparseArrays, DiffEqFlux, DifferentialEquations
22
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
33
using Flux: @epochs
44
using Statistics: mean
@@ -19,7 +19,7 @@ epochs = 40
1919
# Preprocess the data and compute adjacency matrix
2020
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
2121
train_y = Float32.(labels) |> gpu # dim: target_catg * num_nodes
22-
fg = FeaturedGraph(g) |> gpu
22+
fg = FeaturedGraph(g)
2323

2424
# Define the Neural GDE
2525
diffeqarray_to_array(x) = reshape(gpu(x), size(x)[1:2])

examples/vgae.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ epochs = 200
2323
## Preprocessing data
2424
masks = [rand(Float32, num_nodes, num_nodes).>0.1 for i in 1:10]
2525
adj_mat = Matrix{Float32}(adjacency_matrix(g))
26-
train_data = [(FeaturedGraph(adj_mat.*M, Matrix{Float32}(features)), adj_mat) for M in masks]
26+
train_data = [(FeaturedGraph(adj_mat.*M, nf=Matrix{Float32}(features)), adj_mat) for M in masks]
2727

2828
## Model
2929
model = VGAE(GCNConv(num_features=>h_dim, relu;), h_dim, z_dim, σ)
@@ -34,7 +34,7 @@ ps = Flux.params(model)
3434
l2_norm(p) = sum(abs2, p)
3535

3636
function loss(fg, Y, X=node_feature(fg), T=eltype(X), β=one(T), λ=T(0.01); debug=false)
37-
μ̂, logσ̂ = summarize(encoder, fg)
37+
μ̂, logσ̂ = GeometricFlux.summarize(encoder, fg)
3838
Z = node_feature(encoder(fg))
3939
kl_q_p = -T(0.5) * sum(one(T) .+ T(2).*logσ̂ .- μ̂.^2 .- exp.(T(2).*logσ̂))
4040
logp_y_z = -sum(logitbinarycrossentropy(decoder(Z), Y, agg=identity)) / size(Y,2)

0 commit comments

Comments
 (0)