Skip to content

Commit e8905c6

Browse files
committed
fix GATConv layer
1 parent dfe3381 commit e8905c6

File tree

4 files changed

+10
-8
lines changed

4 files changed

+10
-8
lines changed

src/GeometricFlux.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ include("layers/misc.jl")
7878
include("sampling.jl")
7979
include("embedding/node2vec.jl")
8080

81-
include("cuda/conv.jl")
82-
8381
using .Datasets
8482

8583

src/cuda/conv.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

src/layers/conv.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,9 @@ function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
253253
end
254254

255255
function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
256-
@assert check_self_loops(sg) "a vertex must have self loop (receive a message from itself)."
257-
mapreduce(i -> apply_batch_message(gat, i, neighbors(sg, i), X), hcat, 1:nv(sg))
256+
@assert Zygote.ignore(() -> check_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
257+
ys = map(i -> apply_batch_message(gat, i, GraphSignals.cpu_neighbors(sg, i), X), 1:nv(sg))
258+
return hcat(ys...)
258259
end
259260

260261
function check_self_loops(sg::SparseGraph)

test/cuda/conv.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@
7171
end
7272

7373
@testset "GATConv" begin
74+
adj = T[1 1 0 1;
75+
1 1 1 0;
76+
0 1 1 1;
77+
1 0 1 1]
78+
79+
fg = FeaturedGraph(adj)
80+
7481
gat = GATConv(fg, in_channel=>out_channel) |> gpu
7582
@test size(gat.weight) == (out_channel, in_channel)
7683
@test size(gat.bias) == (out_channel,)

0 commit comments

Comments
 (0)