Skip to content

Commit 995b4d8

Browse files
committed
Replace bypass_graph as Bypass layer
1 parent c380eed commit 995b4d8

File tree

3 files changed

+46
-28
lines changed

3 files changed

+46
-28
lines changed

src/GeometricFlux.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,8 @@ export
5454
InnerProductDecoder,
5555
VariationalEncoder,
5656

57-
# layer/selector
58-
bypass_graph,
59-
60-
# utils
61-
generate_cluster,
57+
# layer/misc
58+
Bypass,
6259

6360
#node2vec
6461
node2vec

src/layers/misc.jl

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
"""
2-
bypass_graph(nf_func, ef_func, gf_func)
3-
4-
Bypassing graph in FeaturedGraph and let other layer process (node, edge and global)features only.
2+
Bypassing graph in FeaturedGraph and let other layer process (node, edge and global) features only.
53
"""
6-
function bypass_graph(nf_func=identity, ef_func=identity, gf_func=identity)
7-
return function (fg::FeaturedGraph)
8-
FeaturedGraph(fg,
9-
nf=nf_func(node_feature(fg)),
10-
ef=ef_func(edge_feature(fg)),
11-
gf=gf_func(global_feature(fg)))
12-
end
4+
struct Bypass{N,E,G}
5+
node_layer::N
6+
edge_layer::E
7+
global_layer::G
8+
end
9+
10+
@functor Bypass
11+
12+
Bypass(; node_layer=identity, edge_layer=identity, global_layer=identity) =
13+
Bypass(node_layer, edge_layer, global_layer)
14+
15+
function (l::Bypass)(fg::FeaturedGraph)
16+
nf = l.node_layer(node_feature(fg))
17+
ef = l.edge_layer(edge_feature(fg))
18+
gf = l.global_layer(global_feature(fg))
19+
return FeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
20+
end
21+
22+
function (l::Bypass)(fsg::FeaturedSubgraph)
23+
nf = l.node_layer(node_feature(fsg))
24+
ef = l.edge_layer(edge_feature(fsg))
25+
gf = l.global_layer(global_feature(fsg))
26+
fg = parent(fsg)
27+
vidx = fsg.nodes
28+
nf = NNlib.scatter(+, nf, vidx; init=0, dstsize=(size(nf,1), nv(fg)))
29+
ef = NNlib.scatter(+, ef, edges(fsg); init=0, dstsize=(size(ef,1), ne(fg)))
30+
fg = FeaturedGraph(fg, nf=nf, ef=ef, gf=gf)
31+
return subgraph(fg, vidx)
1332
end

test/layers/misc.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
@testset "misc" begin
2-
@testset "bypass_graph" begin
2+
@testset "Bypass" begin
3+
T = Float32
34
N = 4
45
E = 5
5-
adj = [0 1 1 1;
6-
1 0 1 0;
7-
1 1 0 1;
8-
1 0 1 0]
6+
adj = T[0 1 1 1;
7+
1 0 1 0;
8+
1 1 0 1;
9+
1 0 1 0]
910

1011
nf = rand(3, N)
1112
ef = rand(5, E)
1213
gf = rand(7)
14+
nodes = [1,2,4]
1315

1416
fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf)
15-
layer = bypass_graph(x -> x .+ 1.,
16-
x -> x .+ 2.,
17-
x -> x .+ 3.)
18-
fg_ = layer(fg)
19-
@test GraphSignals.adjacency_matrix(fg_) == adj
20-
@test node_feature(fg_) == nf .+ 1.
21-
@test edge_feature(fg_) == ef .+ 2.
22-
@test global_feature(fg_) == gf .+ 3.
17+
fsg = subgraph(fg, nodes)
18+
19+
layer = Bypass(node_layer=Dropout(0.5),
20+
global_layer=x -> x .+ 3.)
21+
fsg_ = layer(fsg)
22+
@test node_feature(fsg_) == view(nf, :, nodes)
23+
@test edge_feature(fsg_) == view(ef, :, edges(fsg))
24+
@test global_feature(fsg_) == gf .+ 3.
2325
end
2426
end

0 commit comments

Comments
 (0)