@@ -16,7 +16,7 @@ Graph convolutional layer.
1616The input to the layer is a node feature array `X`
1717of size `(num_features, num_nodes)`.
1818"""
19- struct GCNConv{A<: AbstractMatrix , B, F, S<: AbstractFeaturedGraph }
19+ struct GCNConv{A<: AbstractMatrix , B, F, S<: AbstractFeaturedGraph } <: AbstractGraphLayer
2020 weight:: A
2121 bias:: B
2222 σ:: F
@@ -42,7 +42,6 @@ function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
4242end
4343
4444(l:: GCNConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
45- (l:: GCNConv )(x:: AbstractMatrix ) = l (l. fg, x)
4645
4746function Base. show (io:: IO , l:: GCNConv )
4847 out, in = size (l. weight)
@@ -66,7 +65,7 @@ Chebyshev spectral graph convolutional layer.
6665- `bias`: Add learnable bias.
6766- `init`: Weights' initializer.
6867"""
69- struct ChebConv{A<: AbstractArray{<:Number,3} , B, S<: AbstractFeaturedGraph }
68+ struct ChebConv{A<: AbstractArray{<:Number,3} , B, S<: AbstractFeaturedGraph } <: AbstractGraphLayer
7069 weight:: A
7170 bias:: B
7271 fg:: S
@@ -104,7 +103,6 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
104103end
105104
106105(l:: ChebConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
107- (l:: ChebConv )(x:: AbstractMatrix ) = l (l. fg, x)
108106
109107function Base. show (io:: IO , l:: ChebConv )
110108 out, in, k = size (l. weight)
@@ -164,7 +162,6 @@ function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
164162end
165163
166164(l:: GraphConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
167- (l:: GraphConv )(x:: AbstractMatrix ) = l (l. fg, x)
168165
169166function Base. show (io:: IO , l:: GraphConv )
170167 in_channel = size (l. weight1, ndims (l. weight1))
@@ -272,7 +269,6 @@ function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
272269end
273270
274271(l:: GATConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
275- (l:: GATConv )(x:: AbstractMatrix ) = l (l. fg, x)
276272
277273function Base. show (io:: IO , l:: GATConv )
278274 in_channel = size (l. weight, ndims (l. weight))
@@ -340,7 +336,6 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
340336end
341337
342338(l:: GatedGraphConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
343- (l:: GatedGraphConv )(x:: AbstractMatrix ) = l (l. fg, x)
344339
345340
346341function Base. show (io:: IO , l:: GatedGraphConv )
@@ -383,7 +378,6 @@ function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
383378end
384379
385380(l:: EdgeConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf = l (fg, node_feature (fg)))
386- (l:: EdgeConv )(x:: AbstractMatrix ) = l (l. fg, x)
387381
388382function Base. show (io:: IO , l:: EdgeConv )
389383 print (io, " EdgeConv(" , l. nn)
@@ -393,34 +387,34 @@ end
393387
394388
395389"""
396- GINConv([fg,] nn, [eps])
390+ GINConv([fg,] nn, [eps=0])
397391
398392 Graph Isomorphism Network.
399393
400394# Arguments
401395
402396- `fg`: Optionally pass in a FeaturedGraph as input.
403397- `nn`: A neural network/layer.
404- - `eps`: Weighting factor. Default 0.
398+ - `eps`: Weighting factor.
405399
406400The definition of this is as defined in the original paper,
407401Xu et. al. (2018) https://arxiv.org/abs/1810.00826.
408402"""
409- struct GINConv{V <: AbstractFeaturedGraph ,R <: Real } <: MessagePassing
410- fg:: V
403+ struct GINConv{G,R } <: MessagePassing
404+ fg:: G
411405 nn
412406 eps:: R
413- end
414407
415- function GINConv (fg:: AbstractFeaturedGraph , nn; eps= 0f0 )
416- GINConv (fg, nn, eps)
408+ function GINConv (fg:: G , nn, eps:: R = 0f0 ) where {G<: AbstractFeaturedGraph ,R<: Real }
409+ new {G,R} (fg, nn, eps)
410+ end
417411end
418412
419- function GINConv (nn; eps= 0f0 )
413+ function GINConv (nn, eps:: Real = 0f0 )
420414 GINConv (NullGraph (), nn, eps)
421415end
422416
423- Flux. trainable (g:: GINConv ) = (fg= g. fg,nn= g. nn)
417+ Flux. trainable (g:: GINConv ) = (fg= g. fg, nn= g. nn)
424418
425419message (g:: GINConv , x_i:: AbstractVector , x_j:: AbstractVector ) = x_j
426420update (g:: GINConv , m:: AbstractVector , x) = g. nn ((1 + g. eps) * x + m)
@@ -434,12 +428,11 @@ function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
434428 X
435429end
436430
437- (l:: GINConv )(x:: AbstractMatrix ) = l (l. fg, x)
438431(l:: GINConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg. graph, nf = l (fg, node_feature (fg)))
439432
440433
441434"""
442- CGConv([fg,] (node_dim, edge_dim), out, init)
435+ CGConv([fg,] (node_dim, edge_dim), out, init, bias=true, as_edge=false )
443436
444437Crystal Graph Convolutional network. Uses both node and edge features.
445438
@@ -451,18 +444,17 @@ Crystal Graph Convolutional network. Uses both node and edge features.
451444- `out`: Dimensionality of the output features.
452445- `init`: Initialization algorithm for each of the weight matrices
453446- `bias`: Whether or not to learn an additive bias parameter.
447+ - `as_edge`: Whether the matrix passed in a single-matrix call `CGConv(M)` is treated as edge features (`true`) or node features (`false`).
454448
455449# Usage
456450
457451You can call `CGConv` in several different ways:
458452
459453- Pass a FeaturedGraph: `CGConv(fg)`, returns `FeaturedGraph`
460454- Pass both node and edge features: `CGConv(X, E)`
461- - Pass one matrix, which can either be node features or edge features: `CGConv(M; edge)`:
462- `edge` is default false, meaning that `M` denotes node features.
455+ - Pass one matrix, which is interpreted as node features or edge features according to the `as_edge` keyword argument.
463456"""
464- struct CGConv{V <: AbstractFeaturedGraph , T,
465- A <: AbstractMatrix{T} , B} <: MessagePassing
457+ struct CGConv{E, V<: AbstractFeaturedGraph , A<: AbstractMatrix , B} <: MessagePassing
466458 fg:: V
467459 Wf:: A
468460 Ws:: A
@@ -472,18 +464,20 @@ end
472464
473465@functor CGConv
474466
475- function CGConv (fg:: AbstractFeaturedGraph , dims:: NTuple{2,Int} ;
476- init= glorot_uniform, bias= true )
467+ function CGConv (fg:: G , dims:: NTuple{2,Int} ;
468+ init= glorot_uniform, bias= true , as_edge = false ) where {G <: AbstractFeaturedGraph }
477469 node_dim, edge_dim = dims
478470 Wf = init (node_dim, 2 * node_dim + edge_dim)
479471 Ws = init (node_dim, 2 * node_dim + edge_dim)
480472 bf = Flux. create_bias (Wf, bias, node_dim)
481473 bs = Flux. create_bias (Ws, bias, node_dim)
482- CGConv (fg, Wf, Ws, bf, bs)
474+ T, S = typeof (Wf), typeof (bf)
475+
476+ CGConv {as_edge,G,T,S} (fg, Wf, Ws, bf, bs)
483477end
484478
485- function CGConv (dims:: NTuple{2,Int} ; init= glorot_uniform, bias= true )
486- CGConv (NullGraph (), dims; init= init, bias= bias)
479+ function CGConv (dims:: NTuple{2,Int} ; init= glorot_uniform, bias= true , as_edge = false )
480+ CGConv (NullGraph (), dims; init= init, bias= bias, as_edge = as_edge )
487481end
488482
489483message (c:: CGConv ,
503497(l:: CGConv )(fg:: FeaturedGraph ) = FeaturedGraph (fg, nf= l (fg, node_feature (fg),
504498 edge_feature (fg)),
505499 ef= edge_feature (fg))
506- (l:: CGConv )(M:: AbstractMatrix ; as_edge= false ) =
507- if as_edge
508- l (l. fg, node_feature (l. fg), M)
509- else
510- l (l. fg, M, edge_feature (l. fg))
511- end
500+
512501(l:: CGConv )(X:: AbstractMatrix , E:: AbstractMatrix ) = l (l. fg, X, E)
502+
503+ (l:: CGConv{true} )(M:: AbstractMatrix ) = l (l. fg, node_feature (l. fg), M)
504+ (l:: CGConv{false} )(M:: AbstractMatrix ) = l (l. fg, M, edge_feature (l. fg))
0 commit comments