|
436 | 436 |
|
437 | 437 | (l::GINConv)(x::AbstractMatrix) = l(l.fg, x) |
438 | 438 | (l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg))) |
| 439 | + |
| 440 | + |
| 441 | +""" |
| 442 | + CGConv([fg,] (node_dim, edge_dim), out, init) |
| 443 | +
|
| 444 | +Crystal Graph Convolutional network. Uses both node and edge features. |
| 445 | +
|
| 446 | +# Arguments |
| 447 | +
|
| 448 | +- `fg`: Optional [`FeaturedGraph`] argument(@ref) |
| 449 | +- `node_dim`: Dimensionality of the input node features. Also is necessarily the output dimensionality. |
| 450 | +- `edge_dim`: Dimensionality of the input edge features. |
| 451 | +- `out`: Dimensionality of the output features. |
| 452 | +- `init`: Initialization algorithm for each of the weight matrices |
| 453 | +- `bias`: Whether or not to learn an additive bias parameter. |
| 454 | +
|
| 455 | +# Usage |
| 456 | +
|
| 457 | +You can call `CGConv` in several different ways: |
| 458 | + |
| 459 | +- Pass a FeaturedGraph: `CGConv(fg)`, returns `FeaturedGraph` |
| 460 | +- Pass both node and edge features: `CGConv(X, E)` |
| 461 | +- Pass one matrix, which can either be node features or edge features: `CGConv(M; edge)`: |
| 462 | + `edge` is default false, meaning that `M` denotes node features. |
| 463 | +""" |
| 464 | +struct CGConv{V <: AbstractFeaturedGraph, T, |
| 465 | + A <: AbstractMatrix{T}, B} <: MessagePassing |
| 466 | + fg::V |
| 467 | + Wf::A |
| 468 | + Ws::A |
| 469 | + bf::B |
| 470 | + bs::B |
| 471 | +end |
| 472 | + |
| 473 | +@functor CGConv |
| 474 | + |
| 475 | +function CGConv(fg::AbstractFeaturedGraph, dims::NTuple{2,Int}; |
| 476 | + init=glorot_uniform, bias=true) |
| 477 | + node_dim, edge_dim = dims |
| 478 | + Wf = init(node_dim, 2*node_dim + edge_dim) |
| 479 | + Ws = init(node_dim, 2*node_dim + edge_dim) |
| 480 | + bf = Flux.create_bias(Wf, bias, node_dim) |
| 481 | + bs = Flux.create_bias(Ws, bias, node_dim) |
| 482 | + CGConv(fg, Wf, Ws, bf, bs) |
| 483 | +end |
| 484 | + |
| 485 | +function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true) |
| 486 | + CGConv(NullGraph(), dims; init=init, bias=bias) |
| 487 | +end |
| 488 | + |
| 489 | +message(c::CGConv, |
| 490 | + x_i::AbstractVector, x_j::AbstractVector, e::AbstractVector) = begin |
| 491 | + z = vcat(x_i, x_j, e) |
| 492 | + σ.(c.Wf * z + c.bf) .* softplus.(c.Ws * z + c.bs) |
| 493 | +end |
| 494 | +update(c::CGConv, m::AbstractVector, x) = x + m |
| 495 | + |
| 496 | +function (c::CGConv)(fg::FeaturedGraph, X::AbstractMatrix, E::AbstractMatrix) |
| 497 | + check_num_nodes(fg, X) |
| 498 | + check_num_edges(fg, E) |
| 499 | + _, Y = propagate(c, adjacency_list(fg), E, X, +) |
| 500 | + Y |
| 501 | +end |
| 502 | + |
| 503 | +(l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf=l(fg, node_feature(fg), |
| 504 | + edge_feature(fg)), |
| 505 | + ef=edge_feature(fg)) |
| 506 | +(l::CGConv)(M::AbstractMatrix; as_edge=false) = |
| 507 | + if as_edge |
| 508 | + l(l.fg, node_feature(l.fg), M) |
| 509 | + else |
| 510 | + l(l.fg, M, edge_feature(l.fg)) |
| 511 | + end |
| 512 | +(l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E) |
0 commit comments