Skip to content

Commit f1b3c1d

Browse files
authored
Merge pull request #253 from FluxML/develop
Fix message-passing tests
2 parents 16d43e5 + e8905c6 commit f1b3c1d

File tree

8 files changed

+305
-319
lines changed

8 files changed

+305
-319
lines changed

src/GeometricFlux.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using LinearAlgebra: Adjoint, norm, Transpose
77
using Random
88
using Reexport
99

10-
using CUDA
10+
using CUDA, CUDA.CUSPARSE
1111
using ChainRulesCore: @non_differentiable
1212
using FillArrays: Fill
1313
using Flux
@@ -78,8 +78,6 @@ include("layers/misc.jl")
7878
include("sampling.jl")
7979
include("embedding/node2vec.jl")
8080

81-
include("cuda/conv.jl")
82-
8381
using .Datasets
8482

8583

src/cuda/conv.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/layers/conv.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
104104
Y = view(c.weight,:,:,1) * Z_prev
105105
Y += view(c.weight,:,:,2) * Z
106106
for k = 3:c.k
107-
Z, Z_prev = 2*Z*- Z_prev, Z
107+
Z, Z_prev = 2 .* Z * - Z_prev, Z
108108
Y += view(c.weight,:,:,k) * Z
109109
end
110110
return Y .+ c.bias
@@ -253,13 +253,14 @@ function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
253253
end
254254

255255
function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
256-
@assert check_self_loops(sg) "a vertex must have self loop (receive a message from itself)."
257-
mapreduce(i -> apply_batch_message(gat, i, neighbors(sg, i), X), hcat, 1:nv(sg))
256+
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
257+
ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
258+
return hcat(ys...)
258259
end
259260

260261
function check_self_loops(sg::SparseGraph)
261262
for i in 1:nv(sg)
262-
if !(i in GraphSignals.rowvalview(sg.S, i))
263+
if !(i in collect(GraphSignals.rowvalview(sg.S, i)))
263264
return false
264265
end
265266
end

src/layers/gn.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,22 @@ abstract type GraphNet <: AbstractGraphLayer end
1616
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
1717
@inline update_global(gn::GraphNet, ē, v̄, u) = u
1818

19-
@inline update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u) =
20-
mapreduce(i -> apply_batch_message(gn, sg, i, neighbors(sg, i), E, V, u), hcat, vertices(sg))
19+
@inline function update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u)
20+
ys = map(i -> apply_batch_message(gn, sg, i, GraphSignals.cpu_neighbors(sg, i), E, V, u), vertices(sg))
21+
return hcat(ys...)
22+
end
2123

22-
@inline apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u) =
23-
mapreduce(j -> update_edge(gn, _view(E, edge_index(sg, i, j)), _view(V, i), _view(V, j), u), hcat, js)
24+
@inline function apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u)
25+
# js still CuArray
26+
es = Zygote.ignore(() -> GraphSignals.cpu_incident_edges(sg, i))
27+
ys = map(k -> update_edge(gn, _view(E, es[k]), _view(V, i), _view(V, js[k]), u), 1:length(js))
28+
return hcat(ys...)
29+
end
2430

25-
@inline update_batch_vertex(gn::GraphNet, Ē, V, u) =
26-
mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2))
31+
@inline function update_batch_vertex(gn::GraphNet, Ē, V, u)
32+
ys = map(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), 1:size(V,2))
33+
return hcat(ys...)
34+
end
2735

2836
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr, E) = neighbor_scatter(aggr, E, sg)
2937
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr::Nothing, @nospecialize E) = nothing

test/cuda/conv.jl

Lines changed: 73 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -51,70 +51,77 @@
5151
@test size(g.bias) == size(cc.bias)
5252
end
5353

54-
# @testset "GraphConv" begin
55-
# gc = GraphConv(fg, in_channel=>out_channel) |> gpu
56-
# @test size(gc.weight1) == (out_channel, in_channel)
57-
# @test size(gc.weight2) == (out_channel, in_channel)
58-
# @test size(gc.bias) == (out_channel,)
59-
60-
# X = rand(in_channel, N) |> gpu
61-
# Y = gc(X)
62-
# @test size(Y) == (out_channel, N)
63-
64-
# g = Zygote.gradient(x -> sum(gc(x)), X)[1]
65-
# @test size(g) == size(X)
66-
67-
# g = Zygote.gradient(model -> sum(model(X)), gc)[1]
68-
# @test size(g.weight1) == size(gc.weight1)
69-
# @test size(g.weight2) == size(gc.weight2)
70-
# @test size(g.bias) == size(gc.bias)
71-
# end
72-
73-
# @testset "GATConv" begin
74-
# gat = GATConv(fg, in_channel=>out_channel) |> gpu
75-
# @test size(gat.weight) == (out_channel, in_channel)
76-
# @test size(gat.bias) == (out_channel,)
77-
78-
# X = rand(in_channel, N) |> gpu
79-
# Y = gat(X)
80-
# @test size(Y) == (out_channel, N)
81-
82-
# g = Zygote.gradient(x -> sum(gat(x)), X)[1]
83-
# @test size(g) == size(X)
84-
85-
# g = Zygote.gradient(model -> sum(model(X)), gat)[1]
86-
# @test size(g.weight) == size(gat.weight)
87-
# @test size(g.bias) == size(gat.bias)
88-
# @test size(g.a) == size(gat.a)
89-
# end
90-
91-
# @testset "GatedGraphConv" begin
92-
# num_layers = 3
93-
# ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
94-
# @test size(ggc.weight) == (out_channel, out_channel, num_layers)
95-
96-
# X = rand(in_channel, N) |> gpu
97-
# Y = ggc(X)
98-
# @test size(Y) == (out_channel, N)
99-
100-
# g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
101-
# @test size(g) == size(X)
102-
103-
# g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
104-
# @test size(g.weight) == size(ggc.weight)
105-
# end
106-
107-
# @testset "EdgeConv" begin
108-
# ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
109-
# X = rand(in_channel, N) |> gpu
110-
# Y = ec(X)
111-
# @test size(Y) == (out_channel, N)
112-
113-
# g = Zygote.gradient(x -> sum(ec(x)), X)[1]
114-
# @test size(g) == size(X)
115-
116-
# g = Zygote.gradient(model -> sum(model(X)), ec)[1]
117-
# @test size(g.nn.weight) == size(ec.nn.weight)
118-
# @test size(g.nn.bias) == size(ec.nn.bias)
119-
# end
54+
@testset "GraphConv" begin
55+
gc = GraphConv(fg, in_channel=>out_channel) |> gpu
56+
@test size(gc.weight1) == (out_channel, in_channel)
57+
@test size(gc.weight2) == (out_channel, in_channel)
58+
@test size(gc.bias) == (out_channel,)
59+
60+
X = rand(in_channel, N) |> gpu
61+
Y = gc(X)
62+
@test size(Y) == (out_channel, N)
63+
64+
g = Zygote.gradient(x -> sum(gc(x)), X)[1]
65+
@test size(g) == size(X)
66+
67+
g = Zygote.gradient(model -> sum(model(X)), gc)[1]
68+
@test size(g.weight1) == size(gc.weight1)
69+
@test size(g.weight2) == size(gc.weight2)
70+
@test size(g.bias) == size(gc.bias)
71+
end
72+
73+
@testset "GATConv" begin
74+
adj = T[1 1 0 1;
75+
1 1 1 0;
76+
0 1 1 1;
77+
1 0 1 1]
78+
79+
fg = FeaturedGraph(adj)
80+
81+
gat = GATConv(fg, in_channel=>out_channel) |> gpu
82+
@test size(gat.weight) == (out_channel, in_channel)
83+
@test size(gat.bias) == (out_channel,)
84+
85+
X = rand(in_channel, N) |> gpu
86+
Y = gat(X)
87+
@test size(Y) == (out_channel, N)
88+
89+
g = Zygote.gradient(x -> sum(gat(x)), X)[1]
90+
@test size(g) == size(X)
91+
92+
g = Zygote.gradient(model -> sum(model(X)), gat)[1]
93+
@test size(g.weight) == size(gat.weight)
94+
@test size(g.bias) == size(gat.bias)
95+
@test size(g.a) == size(gat.a)
96+
end
97+
98+
@testset "GatedGraphConv" begin
99+
num_layers = 3
100+
ggc = GatedGraphConv(fg, out_channel, num_layers) |> gpu
101+
@test size(ggc.weight) == (out_channel, out_channel, num_layers)
102+
103+
X = rand(in_channel, N) |> gpu
104+
Y = ggc(X)
105+
@test size(Y) == (out_channel, N)
106+
107+
g = Zygote.gradient(x -> sum(ggc(x)), X)[1]
108+
@test size(g) == size(X)
109+
110+
g = Zygote.gradient(model -> sum(model(X)), ggc)[1]
111+
@test size(g.weight) == size(ggc.weight)
112+
end
113+
114+
@testset "EdgeConv" begin
115+
ec = EdgeConv(fg, Dense(2*in_channel, out_channel)) |> gpu
116+
X = rand(in_channel, N) |> gpu
117+
Y = ec(X)
118+
@test size(Y) == (out_channel, N)
119+
120+
g = Zygote.gradient(x -> sum(ec(x)), X)[1]
121+
@test size(g) == size(X)
122+
123+
g = Zygote.gradient(model -> sum(model(X)), ec)[1]
124+
@test size(g.nn.weight) == size(ec.nn.weight)
125+
@test size(g.nn.bias) == size(ec.nn.bias)
126+
end
120127
end

0 commit comments

Comments
 (0)