Commit b806968

Merge pull request #220 from emsal0/CGConv
Implement CGConv layer.
2 parents e134f54 + 8742cf6

5 files changed (+116, -0 lines)

docs/src/manual/conv.md

Lines changed: 14 additions & 0 deletions
@@ -120,3 +120,17 @@ where ``f_{\Theta}`` denotes a neural network parametrized by ``\Theta``, *i.e.*
GINConv
```
Reference: [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)

## Crystal Graph Convolutional Network

```math
\textbf{x}_i' = \textbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma\left( \textbf{z}_{i,j} \textbf{W}_f + \textbf{b}_f \right) \odot \text{softplus}\left(\textbf{z}_{i,j} \textbf{W}_s + \textbf{b}_s \right)
```

where ``\textbf{z}_{i,j} = [\textbf{x}_i, \textbf{x}_j, \textbf{e}_{i,j}]`` denotes the concatenation of node features, neighboring node features, and edge features. The operation ``\odot`` represents elementwise multiplication, and ``\sigma`` denotes the sigmoid function.

```@docs
CGConv
```

Reference: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf)
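
A minimal usage sketch of the layer this section documents (the toy adjacency matrix and feature dimensions are made up for illustration; `ne(fg)` is assumed available for counting edges, as in the tests added by this PR):

```julia
using GeometricFlux

# A toy 4-node cycle graph as an adjacency matrix.
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
fg = FeaturedGraph(adj)

node_dim, edge_dim = 3, 1
cgc = CGConv(fg, (node_dim, edge_dim))

X = rand(Float32, node_dim, 4)        # node features: one column per node
E = rand(Float32, edge_dim, ne(fg))   # edge features: one column per edge
Y = cgc(X, E)                         # size(Y) == (node_dim, 4), same as X
```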

src/GeometricFlux.jl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ export
    GatedGraphConv,
    EdgeConv,
    GINConv,
    CGConv,

    # layer/pool
    GlobalPool,

src/layers/conv.jl

Lines changed: 74 additions & 0 deletions
@@ -436,3 +436,77 @@ end
(l::GINConv)(x::AbstractMatrix) = l(l.fg, x)
(l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))


"""
    CGConv([fg,] (node_dim, edge_dim); init=glorot_uniform, bias=true)

Crystal Graph Convolutional network. Uses both node and edge features.

# Arguments

- `fg`: Optional [`FeaturedGraph`](@ref) argument.
- `node_dim`: Dimensionality of the input node features. The output node features have the same dimensionality.
- `edge_dim`: Dimensionality of the input edge features.
- `init`: Initialization algorithm for each of the weight matrices.
- `bias`: Whether or not to learn an additive bias parameter.

# Usage

You can call `CGConv` in several different ways:

- Pass a `FeaturedGraph`: `CGConv(fg)` returns a `FeaturedGraph`.
- Pass both node and edge features: `CGConv(X, E)`.
- Pass one matrix `M`, holding either node or edge features: `CGConv(M; as_edge)`.
  `as_edge` defaults to `false`, meaning that `M` denotes node features.
"""
struct CGConv{V <: AbstractFeaturedGraph, T,
              A <: AbstractMatrix{T}, B} <: MessagePassing
    fg::V
    Wf::A
    Ws::A
    bf::B
    bs::B
end

@functor CGConv

function CGConv(fg::AbstractFeaturedGraph, dims::NTuple{2,Int};
                init=glorot_uniform, bias=true)
    node_dim, edge_dim = dims
    Wf = init(node_dim, 2*node_dim + edge_dim)
    Ws = init(node_dim, 2*node_dim + edge_dim)
    bf = Flux.create_bias(Wf, bias, node_dim)
    bs = Flux.create_bias(Ws, bias, node_dim)
    CGConv(fg, Wf, Ws, bf, bs)
end

function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true)
    CGConv(NullGraph(), dims; init=init, bias=bias)
end

function message(c::CGConv,
                 x_i::AbstractVector, x_j::AbstractVector, e::AbstractVector)
    z = vcat(x_i, x_j, e)
    σ.(c.Wf * z + c.bf) .* softplus.(c.Ws * z + c.bs)
end

update(c::CGConv, m::AbstractVector, x) = x + m

function (c::CGConv)(fg::FeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
    check_num_nodes(fg, X)
    check_num_edges(fg, E)
    _, Y = propagate(c, adjacency_list(fg), E, X, +)
    Y
end

(l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf=l(fg, node_feature(fg),
                                                        edge_feature(fg)),
                                               ef=edge_feature(fg))
(l::CGConv)(M::AbstractMatrix; as_edge=false) =
    if as_edge
        l(l.fg, node_feature(l.fg), M)
    else
        l(l.fg, M, edge_feature(l.fg))
    end
(l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
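
As a sanity check on the gating formula, here is a small self-contained sketch of what `message` and `update` compute for a single edge. It uses plain Flux with hypothetical dimensions and freshly initialized parameters shaped the way the constructor above builds them:

```julia
using Flux
using Flux: glorot_uniform, softplus, σ

node_dim, edge_dim = 3, 1

# Hypothetical parameters, shaped like those built by the CGConv constructor.
Wf = glorot_uniform(node_dim, 2node_dim + edge_dim)
Ws = glorot_uniform(node_dim, 2node_dim + edge_dim)
bf = zeros(Float32, node_dim)
bs = zeros(Float32, node_dim)

x_i = rand(Float32, node_dim)   # features of node i
x_j = rand(Float32, node_dim)   # features of a neighbor j
e   = rand(Float32, edge_dim)   # features of edge (i, j)

z = vcat(x_i, x_j, e)                          # z_ij = [x_i; x_j; e_ij]
m = σ.(Wf * z + bf) .* softplus.(Ws * z + bs)  # one gated message
x_i_new = x_i + m                              # `update` adds the aggregated messages
```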

src/utils.jl

Lines changed: 4 additions & 0 deletions
@@ -77,3 +77,7 @@ end
function check_num_nodes(fg::FeaturedGraph, x::AbstractArray)
    @assert nv(fg) == size(x, ndims(x))
end

function check_num_edges(fg::FeaturedGraph, e::AbstractArray)
    @assert ne(fg) == size(e, ndims(e))
end
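
Both helpers encode the same layout convention: the node or edge axis is the last dimension of a feature array. A quick sketch of what the assertions accept (sizes are hypothetical):

```julia
# For a graph fg with nv(fg) == 4 nodes and ne(fg) == 8 edges:
X = rand(Float32, 3, 4)   # size(X, ndims(X)) == 4 == nv(fg), so check_num_nodes passes
E = rand(Float32, 1, 8)   # size(E, ndims(E)) == 8 == ne(fg), so check_num_edges passes
```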

test/layers/conv.jl

Lines changed: 23 additions & 0 deletions
@@ -1,4 +1,5 @@
in_channel = 3
in_channel_edge = 1
out_channel = 5
N = 4
T = Float32
@@ -379,4 +380,26 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
        @test !in(:eps, Flux.trainable(gc))
    end
end

@testset "CGConv" begin
    fg = FeaturedGraph(adj)
    X = rand(Float32, in_channel, N)
    E = rand(Float32, in_channel_edge, ne(fg))
    Xt = transpose(rand(Float32, N, in_channel))
    @testset "layer with graph" begin
        cgc = CGConv(FeaturedGraph(adj),
                     (in_channel, in_channel_edge))
        @test size(cgc.Wf) == (in_channel, 2 * in_channel + in_channel_edge)
        @test size(cgc.Ws) == (in_channel, 2 * in_channel + in_channel_edge)
        @test size(cgc.bf) == (in_channel,)
        @test size(cgc.bs) == (in_channel,)

        Y = cgc(X, E)
        @test size(Y) == (in_channel, N)

        Yg = cgc(FeaturedGraph(adj, nf=X, ef=E))
        @test size(node_feature(Yg)) == (in_channel, N)
        @test edge_feature(Yg) == E
    end
end
end
