Skip to content

Commit c380eed

Browse files
committed
drop not used functions and refactor
1 parent 5a8ea5c commit c380eed

File tree

3 files changed

+23
-31
lines changed

3 files changed

+23
-31
lines changed

src/layers/gn.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,7 @@
11
_view(::Nothing, idx) = nothing
22
_view(A::Fill{T,2,Axes}, idx) where {T,Axes} = fill(A.value, A.axes[1], length(idx))
3-
4-
function _view(A::SubArray{T,2,S}, idx) where {T,S<:Fill}
5-
p = parent(A)
6-
return Fill(p.value, p.axes[1].stop, length(idx))
7-
end
8-
93
_view(A::AbstractMatrix, idx) = view(A, :, idx)
104

11-
function _view(A::SubArray{T,2,S}, idxs) where {T,S<:AbstractMatrix}
12-
view_idx = A.indices[2]
13-
if view_idx == idxs
14-
return A
15-
else
16-
idxs = findall(x -> x in idxs, view_idx)
17-
return view(A, :, idxs)
18-
end
19-
end
20-
215
aggregate(aggr::typeof(+), X) = vec(sum(X, dims=2))
226
aggregate(aggr::typeof(-), X) = -vec(sum(X, dims=2))
237
aggregate(aggr::typeof(*), X) = vec(prod(X, dims=2))
@@ -32,10 +16,16 @@ abstract type GraphNet <: AbstractGraphLayer end
3216
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
3317
@inline update_global(gn::GraphNet, ē, v̄, u) = u
3418

19+
function _get_indices(fg::AbstractFeaturedGraph)
20+
es = cpu(GraphSignals.incident_edges(fg))
21+
xs = cpu(GraphSignals.repeat_nodes(fg))
22+
nbrs = cpu(GraphSignals.neighbors(fg))
23+
sorted_idx = sort!(collect(zip(es, xs, nbrs)), by=x->x[1])
24+
return collect.(collect(zip(sorted_idx...)))
25+
end
26+
3527
@inline function update_batch_edge(gn::GraphNet, fg::AbstractFeaturedGraph, E, V, u)
36-
es = Zygote.ignore(()->cpu(GraphSignals.incident_edges(fg)))
37-
xs = Zygote.ignore(()->cpu(GraphSignals.repeat_nodes(fg)))
38-
nbrs = Zygote.ignore(()->cpu(GraphSignals.neighbors(fg)))
28+
es, xs, nbrs = Zygote.ignore(()->_get_indices(fg))
3929
ms = map((e,i,j)->update_edge(gn, _view(E, e), _view(V, i), _view(V, j), u), es, xs, nbrs)
4030
M = hcat_by_sum(ms)
4131
return M
@@ -50,7 +40,7 @@ end
5040

5141
@inline function aggregate_neighbors(gn::GraphNet, fg::AbstractFeaturedGraph, aggr, E)
5242
N = nv(parent(fg))
53-
xs = Zygote.ignore(()->cpu(GraphSignals.repeat_nodes(fg)))
43+
es, xs, nbrs = Zygote.ignore(()->_get_indices(fg))
5444
= NNlib.scatter(aggr, E, xs; dstsize=(size(E, 1), N))
5545
return
5646
end
@@ -67,6 +57,14 @@ function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, naggr=nothing, eaggr
6757
FeaturedGraph(fg, nf=V, ef=E, gf=u)
6858
end
6959

60+
"""
61+
- `update_batch_edge`: (E_in_dim, E) -> (E_out_dim, E)
62+
- `aggregate_neighbors`: (E_out_dim, E) -> (E_out_dim, V)
63+
- `update_batch_vertex`: (V_in_dim, V) -> (V_out_dim, V)
64+
- `aggregate_edges`: (E_out_dim, E) -> (E_out_dim,)
65+
- `aggregate_vertices`: (V_out_dim, V) -> (V_out_dim,)
66+
- `update_global`: (dim,) -> (dim,)
67+
"""
7068
function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, E::AbstractArray, V::AbstractArray, u::AbstractArray,
7169
naggr=nothing, eaggr=nothing, vaggr=nothing)
7270
E = update_batch_edge(gn, fg, E, V, u)
@@ -75,5 +73,5 @@ function propagate(gn::GraphNet, fg::AbstractFeaturedGraph, E::AbstractArray, V:
7573
= aggregate_edges(gn, eaggr, E)
7674
= aggregate_vertices(gn, vaggr, V)
7775
u = update_global(gn, ē, v̄, u)
78-
return parent(E), parent(V), u
76+
return E, V, u
7977
end

src/utils.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,3 @@ function ChainRulesCore.rrule(::typeof(hcat_by_sum), xs::AbstractVector)
2424
hcat_by_sum_pullback(Δ) = (NoTangent(), ntuple(i->view(Δ,:,ns[i]:(ns[i+1]-1)), N))
2525
hcat_by_sum(xs), hcat_by_sum_pullback
2626
end
27-
28-
function ChainRulesCore.rrule(::typeof(parent), A::Base.SubArray)
29-
parent_pullback(Δ) = (NoTangent(), view(Δ, A.indices...))
30-
parent(A), parent_pullback
31-
end

test/layers/gn.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
V = 6
66
E = 7
77

8-
nf = rand(T, in_channel, V)
9-
ef = rand(T, in_channel, E)
8+
nf = hcat([repeat(T[i], in_channel) for i in 1:V]...)
9+
ef = hcat([repeat(T[i], in_channel) for i in 1:E]...)
1010
gf = rand(T, in_channel)
1111

1212
adj = T[0. 1. 0. 0. 0. 0.;
@@ -17,18 +17,17 @@
1717
0. 1. 1. 0. 1. 0.]
1818

1919
struct NewGNLayer <: GraphNet end
20-
21-
l = NewGNLayer()
2220

2321
@testset "without aggregation" begin
2422
function (l::NewGNLayer)(fg::FeaturedGraph)
2523
GeometricFlux.propagate(l, fg, edge_feature(fg), node_feature(fg), global_feature(fg))
2624
end
2725

2826
fg = FeaturedGraph(adj, nf=nf)
27+
l = NewGNLayer()
2928
ef_, nf_, gf_ = l(fg)
3029

31-
@test size(nf_) == (in_channel, V)
30+
@test nf_ == nf
3231
@test size(ef_) == (0, 2E)
3332
@test size(gf_) == (0,)
3433
end

0 commit comments

Comments
 (0)