
Commit 40c0bd6

Merge pull request #205 from FluxML/develop
Add AbstractGraphLayer for layers that accept graphs
2 parents b806968 + 5302997 commit 40c0bd6

7 files changed: +41 additions, -42 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ NNlib = "0.7"
 NNlibCUDA = "0.1"
 Reexport = "1.1"
 Zygote = "0.6"
-julia = "1.6"
+julia = "1.6 - 1.7"

 [extras]
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
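Note the compat semantics here: the old caret specifier `julia = "1.6"` admits every 1.x release from 1.6.0 onward, while the new hyphenated range `julia = "1.6 - 1.7"` restricts compatibility to versions from 1.6.0 up through the last 1.7 patch release, excluding 1.8 and later.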

src/GeometricFlux.jl

Lines changed: 4 additions & 0 deletions
@@ -15,6 +15,9 @@ using LightGraphs
 using Zygote

 export
+    # layers/graphlayers
+    AbstractGraphLayer,
+
     # layers/gn
     GraphNet,

@@ -55,6 +58,7 @@ include("datasets.jl")

 include("utils.jl")

+include("layers/graphlayers.jl")
 include("layers/gn.jl")
 include("layers/msgpass.jl")
src/layers/conv.jl

Lines changed: 26 additions & 34 deletions
@@ -16,7 +16,7 @@ Graph convolutional layer.
 The input to the layer is a node feature array `X`
 of size `(num_features, num_nodes)`.
 """
-struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph}
+struct GCNConv{A<:AbstractMatrix, B, F, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
     weight::A
     bias::B
     σ::F

@@ -42,7 +42,6 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
 end

 (l::GCNConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GCNConv)(x::AbstractMatrix) = l(l.fg, x)

 function Base.show(io::IO, l::GCNConv)
     out, in = size(l.weight)

@@ -66,7 +65,7 @@ Chebyshev spectral graph convolutional layer.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
 """
-struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph}
+struct ChebConv{A<:AbstractArray{<:Number,3}, B, S<:AbstractFeaturedGraph} <: AbstractGraphLayer
     weight::A
     bias::B
     fg::S

@@ -104,7 +103,6 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
 end

 (l::ChebConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::ChebConv)(x::AbstractMatrix) = l(l.fg, x)

 function Base.show(io::IO, l::ChebConv)
     out, in, k = size(l.weight)

@@ -164,7 +162,6 @@ function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
 end

 (l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GraphConv)(x::AbstractMatrix) = l(l.fg, x)

 function Base.show(io::IO, l::GraphConv)
     in_channel = size(l.weight1, ndims(l.weight1))

@@ -272,7 +269,6 @@ function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
 end

 (l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GATConv)(x::AbstractMatrix) = l(l.fg, x)

 function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))

@@ -340,7 +336,6 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
 end

 (l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::GatedGraphConv)(x::AbstractMatrix) = l(l.fg, x)


 function Base.show(io::IO, l::GatedGraphConv)

@@ -383,7 +378,6 @@ function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
 end

 (l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
-(l::EdgeConv)(x::AbstractMatrix) = l(l.fg, x)

 function Base.show(io::IO, l::EdgeConv)
     print(io, "EdgeConv(", l.nn)
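The deletions above all remove the same hand-rolled method, `(l::Layer)(x::AbstractMatrix) = l(l.fg, x)`, from each convolutional layer; the single fallback now defined on `AbstractGraphLayer` replaces them all (see the sketch after the src/layers/graphlayers.jl diff below).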
@@ -393,34 +387,34 @@ end


 """
-    GINConv([fg,] nn, [eps])
+    GINConv([fg,] nn, [eps=0])

 Graph Isomorphism Network.

 # Arguments

 - `fg`: Optionally pass in a FeaturedGraph as input.
 - `nn`: A neural network/layer.
-- `eps`: Weighting factor. Default 0.
+- `eps`: Weighting factor.

 The definition of this is as defined in the original paper,
 Xu et. al. (2018) https://arxiv.org/abs/1810.00826.
 """
-struct GINConv{V<:AbstractFeaturedGraph,R<:Real} <: MessagePassing
-    fg::V
+struct GINConv{G,R} <: MessagePassing
+    fg::G
     nn
     eps::R
-end

-function GINConv(fg::AbstractFeaturedGraph, nn; eps=0f0)
-    GINConv(fg, nn, eps)
+    function GINConv(fg::G, nn, eps::R=0f0) where {G<:AbstractFeaturedGraph,R<:Real}
+        new{G,R}(fg, nn, eps)
+    end
 end

-function GINConv(nn; eps=0f0)
+function GINConv(nn, eps::Real=0f0)
     GINConv(NullGraph(), nn, eps)
 end

-Flux.trainable(g::GINConv) = (fg=g.fg,nn=g.nn)
+Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)

 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)

@@ -434,12 +428,11 @@ function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
     X
 end

-(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
 (l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))


 """
-    CGConv([fg,] (node_dim, edge_dim), out, init)
+    CGConv([fg,] (node_dim, edge_dim), out, init, bias=true, as_edge=false)

 Crystal Graph Convolutional network. Uses both node and edge features.
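The reworked `GINConv` moves construction into an inner constructor: the bounds `G<:AbstractFeaturedGraph` and `R<:Real` are now enforced where the object is built rather than on the struct's own parameters, and `eps` becomes a positional argument with a default instead of a keyword. A minimal sketch of this inner-constructor pattern, with a hypothetical `Wrapped` type (not part of the package) standing in for `GINConv`:

struct Wrapped{G,R}
    fg::G
    eps::R

    # Inner constructor: the subtype bounds are checked at construction
    # time, while the struct parameters {G,R} themselves stay unconstrained.
    function Wrapped(fg::G, eps::R=0f0) where {G<:AbstractString,R<:Real}
        new{G,R}(fg, eps)
    end
end

Wrapped("g")       # Wrapped{String,Float32}("g", 0.0f0), default eps
Wrapped("g", 0.1)  # Wrapped{String,Float64}("g", 0.1)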
@@ -451,18 +444,17 @@ Crystal Graph Convolutional network. Uses both node and edge features.
 - `out`: Dimensionality of the output features.
 - `init`: Initialization algorithm for each of the weight matrices
 - `bias`: Whether or not to learn an additive bias parameter.
+- `as_edge`: Whether a single-matrix call `CGConv(M)` treats `M` as node features or edge features.

 # Usage

 You can call `CGConv` in several different ways:

 - Pass a FeaturedGraph: `CGConv(fg)`, returns `FeaturedGraph`
 - Pass both node and edge features: `CGConv(X, E)`
-- Pass one matrix, which can either be node features or edge features: `CGConv(M; edge)`:
-  `edge` is default false, meaning that `M` denotes node features.
+- Pass one matrix, which is treated as node features or edge features according to the `as_edge` keyword argument.
 """
-struct CGConv{V <: AbstractFeaturedGraph, T,
-              A <: AbstractMatrix{T}, B} <: MessagePassing
+struct CGConv{E, V<:AbstractFeaturedGraph, A<:AbstractMatrix, B} <: MessagePassing
     fg::V
     Wf::A
     Ws::A

@@ -472,18 +464,20 @@ end

 @functor CGConv

-function CGConv(fg::AbstractFeaturedGraph, dims::NTuple{2,Int};
-                init=glorot_uniform, bias=true)
+function CGConv(fg::G, dims::NTuple{2,Int};
+                init=glorot_uniform, bias=true, as_edge=false) where {G<:AbstractFeaturedGraph}
     node_dim, edge_dim = dims
     Wf = init(node_dim, 2*node_dim + edge_dim)
     Ws = init(node_dim, 2*node_dim + edge_dim)
     bf = Flux.create_bias(Wf, bias, node_dim)
     bs = Flux.create_bias(Ws, bias, node_dim)
-    CGConv(fg, Wf, Ws, bf, bs)
+    T, S = typeof(Wf), typeof(bf)
+
+    CGConv{as_edge,G,T,S}(fg, Wf, Ws, bf, bs)
 end

-function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true)
-    CGConv(NullGraph(), dims; init=init, bias=bias)
+function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true, as_edge=false)
+    CGConv(NullGraph(), dims; init=init, bias=bias, as_edge=as_edge)
 end

 message(c::CGConv,

@@ -503,10 +497,8 @@ end
 (l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf=l(fg, node_feature(fg),
                                                         edge_feature(fg)),
                                                ef=edge_feature(fg))
-(l::CGConv)(M::AbstractMatrix; as_edge=false) =
-    if as_edge
-        l(l.fg, node_feature(l.fg), M)
-    else
-        l(l.fg, M, edge_feature(l.fg))
-    end
+
 (l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
+
+(l::CGConv{true})(M::AbstractMatrix) = l(l.fg, node_feature(l.fg), M)
+(l::CGConv{false})(M::AbstractMatrix) = l(l.fg, M, edge_feature(l.fg))
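The keyword-and-`if` method is replaced by dispatch: `as_edge` is baked into the new leading type parameter `E`, so `CGConv{true}` and `CGConv{false}` each get their own single-matrix method and the run-time branch disappears. A minimal sketch of this value-as-type-parameter pattern, using a hypothetical `Toy` layer (not part of the package):

# A Bool type parameter selects behavior via dispatch instead of `if`.
struct Toy{E}
    x::Vector{Float64}
end

# The flag is fixed at construction and becomes part of the type.
Toy(x; as_edge::Bool=false) = Toy{as_edge}(x)

(t::Toy{true})(M)  = (:edge, M)   # treat M as edge features
(t::Toy{false})(M) = (:node, M)   # treat M as node features

Toy([1.0]; as_edge=true)("M")   # (:edge, "M")
Toy([1.0])("M")                 # (:node, "M")

One consequence of this design: the flag is now part of the layer's concrete type, so `as_edge=true` and `as_edge=false` layers are different types, and Julia compiles a specialized method for each instead of branching on a field at run time.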

src/layers/gn.jl

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ aggregate(aggr::typeof(max), X) = vec(maximum(X, dims=2))
 aggregate(aggr::typeof(min), X) = vec(minimum(X, dims=2))
 aggregate(aggr::typeof(mean), X) = vec(aggr(X, dims=2))

-abstract type GraphNet end
+abstract type GraphNet <: AbstractGraphLayer end

 @inline update_edge(gn::GraphNet, e, vi, vj, u) = e
 @inline update_vertex(gn::GraphNet, ē, vi, u) = vi

src/layers/graphlayers.jl

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+abstract type AbstractGraphLayer end
+
+(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)
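These three lines are the core of the PR: any layer that subtypes `AbstractGraphLayer` and stores its graph in a field named `fg` inherits the one-argument matrix call for free. A minimal sketch of a user-defined layer relying on the fallback (the `MyLayer` type and its forward pass are illustrative only; `FeaturedGraph` and `AbstractFeaturedGraph` are assumed to be in scope via GeometricFlux's re-exports):

using GeometricFlux

# Hypothetical layer: subtype AbstractGraphLayer and keep the graph in `fg`.
struct MyLayer{S<:AbstractFeaturedGraph, A<:AbstractMatrix} <: AbstractGraphLayer
    fg::S
    weight::A
end

# Only the two-argument (graph, features) method needs to be defined...
(l::MyLayer)(fg::FeaturedGraph, x::AbstractMatrix) = l.weight * x

# ...the inherited fallback (l::AbstractGraphLayer)(x) = l(l.fg, x)
# then supplies the one-argument form.
adj = [0 1; 1 0]
l = MyLayer(FeaturedGraph(adj), rand(4, 3))
l(rand(3, 2))   # 4x2 output, dispatched through l(l.fg, x)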

test/cuda/conv.jl

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@ N = 4
 adj = [0 1 0 1;
        1 0 1 0;
        0 1 0 1;
-       1 0 1 0]
+       1 0 1 0] |> gpu

 fg = FeaturedGraph(adj)

@@ -15,7 +15,7 @@ fg = FeaturedGraph(adj)
     gc = GCNConv(fg, in_channel=>out_channel) |> gpu
     @test size(gc.weight) == (out_channel, in_channel)
     @test size(gc.bias) == (out_channel,)
-    @test Array(adjacency_matrix(gc.fg)) == adj
+    @test collect(graph(gc.fg)) == Array(adj)

     X = rand(in_channel, N) |> gpu
     Y = gc(X)

@@ -35,7 +35,7 @@ fg = FeaturedGraph(adj)
     cc = ChebConv(fg, in_channel=>out_channel, k) |> gpu
     @test size(cc.weight) == (out_channel, in_channel, k)
     @test size(cc.bias) == (out_channel,)
-    @test Array(adjacency_matrix(cc.fg)) == adj
+    @test collect(graph(cc.fg)) == Array(adj)
     @test cc.k == k
     @test cc.in_channel == in_channel
     @test cc.out_channel == out_channel

@@ -44,8 +44,8 @@ fg = FeaturedGraph(adj)
     Y = cc(X)
     @test size(Y) == (out_channel, N)

-    # g = Zygote.gradient(x -> sum(cc(x)), X)[1]
-    # @test size(g) == size(X)
+    g = Zygote.gradient(x -> sum(cc(x)), X)[1]
+    @test size(g) == size(X)

     # g = Zygote.gradient(model -> sum(model(X)), cc)[1]
     # @test size(g.weight) == size(cc.weight)

test/layers/conv.jl

Lines changed: 1 addition & 1 deletion
@@ -357,7 +357,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
 eps = 0.001

 @testset "layer with graph" begin
-    gc = GINConv(FeaturedGraph(adj), nn, eps=eps)
+    gc = GINConv(FeaturedGraph(adj), nn, eps)
     @test size(gc.nn.layers[1].weight) == (out_channel, in_channel)
     @test size(gc.nn.layers[1].bias) == (out_channel, )
     @test graph(gc.fg) === adj
