Skip to content

Commit 3b24f6e

Browse files
committed
refactor
fix example and deps reorder using use bib for doc
1 parent d87c0bf commit 3b24f6e

File tree

12 files changed

+94
-140
lines changed

12 files changed

+94
-140
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ version = "0.8.0"
66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9-
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
109
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1110
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1211
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -35,12 +34,14 @@ Graphs = "1.4"
3534
NNlib = "0.7"
3635
NNlibCUDA = "0.1"
3736
Reexport = "1.1"
37+
Word2Vec = "0.5"
3838
Zygote = "0.6"
3939
julia = "1.6 - 1.7"
4040

4141
[extras]
42+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
4243
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4344
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4445

4546
[targets]
46-
test = ["SparseArrays", "Test"]
47+
test = ["Clustering", "SparseArrays", "Test"]

docs/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
34

45
[compat]
5-
Documenter = "0.24"
6+
Documenter = "0.27"

docs/bibliography.bib

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@inproceedings{node2vec2016,
2+
author = {Grover, Aditya and Leskovec, Jure},
3+
title = {Node2vec: Scalable Feature Learning for Networks},
4+
year = {2016},
5+
isbn = {9781450342322},
6+
publisher = {Association for Computing Machinery},
7+
address = {New York, NY, USA},
8+
url = {https://doi.org/10.1145/2939672.2939754},
9+
doi = {10.1145/2939672.2939754},
10+
abstract = {Prediction tasks over nodes and edges in networks require careful effort in engineering features used by learning algorithms. Recent research in the broader field of representation learning has led to significant progress in automating prediction by learning the features themselves. However, present feature learning approaches are not expressive enough to capture the diversity of connectivity patterns observed in networks. Here we propose node2vec, an algorithmic framework for learning continuous feature representations for nodes in networks. In node2vec, we learn a mapping of nodes to a low-dimensional space of features that maximizes the likelihood of preserving network neighborhoods of nodes. We define a flexible notion of a node's network neighborhood and design a biased random walk procedure, which efficiently explores diverse neighborhoods. Our algorithm generalizes prior work which is based on rigid notions of network neighborhoods, and we argue that the added flexibility in exploring neighborhoods is the key to learning richer representations.We demonstrate the efficacy of node2vec over existing state-of-the-art techniques on multi-label classification and link prediction in several real-world networks from diverse domains. Taken together, our work represents a new way for efficiently learning state-of-the-art task-independent representations in complex networks.},
11+
booktitle = {Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining},
12+
pages = {855–864},
13+
numpages = {10},
14+
keywords = {node embeddings, information networks, graph representations, feature learning},
15+
location = {San Francisco, California, USA},
16+
series = {KDD '16}
17+
}

docs/make.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
using Documenter
2+
using DocumenterCitations
23
using GeometricFlux
34

5+
bib = CitationBibliography(joinpath(@__DIR__, "bibliography.bib"), sorting=:nyt)
6+
47
makedocs(
8+
bib,
59
sitename = "GeometricFlux.jl",
610
format = Documenter.HTML(
711
assets = ["assets/flux.css"],
@@ -24,7 +28,8 @@ makedocs(
2428
["Convolutional Layers" => "manual/conv.md",
2529
"Pooling Layers" => "manual/pool.md",
2630
"Models" => "manual/models.md",
27-
"Linear Algebra" => "manual/linalg.md"]
31+
"Linear Algebra" => "manual/linalg.md"],
32+
"References" => "references.md",
2833
]
2934
)
3035

docs/src/references.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# References
2+
3+
```@bibliography
4+
```

src/GeometricFlux.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
module GeometricFlux
22

3+
using DelimitedFiles
4+
using SparseArrays
35
using Statistics: mean
46
using LinearAlgebra: Adjoint, norm, Transpose
7+
using Random
58
using Reexport
69

710
using CUDA
811
using ChainRulesCore: @non_differentiable
912
using FillArrays: Fill
1013
using Flux
1114
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
12-
using NNlib, NNlibCUDA
1315
@reexport using GraphSignals
1416
using Graphs
15-
using Random
17+
using NNlib, NNlibCUDA
1618
using Zygote
17-
using SparseArrays
18-
using DelimitedFiles
1919

20-
import Graphs: neighbors, is_directed, has_edge
2120
import Word2Vec: word2vec, wordvectors, get_vector
2221

2322
export
@@ -76,8 +75,8 @@ include("layers/pool.jl")
7675
include("models.jl")
7776
include("layers/misc.jl")
7877

79-
include("graph_embedding/sampling.jl")
80-
include("graph_embedding/node2vec.jl")
78+
include("sampling.jl")
79+
include("embedding/node2vec.jl")
8180

8281
include("cuda/conv.jl")
8382

src/graph_embedding/node2vec.jl renamed to src/embedding/node2vec.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,19 @@ const Alias = Tuple{SparseVector{Int}, SparseVector{Float64}}
33
"""
44
node2vec(g; walks_per_node, len, p, q, dims)
55
6-
Computes node embeddings on graph `g`, as per [1]. Performs biased random walks on the graph,
7-
then computes word embeddings by treating those random walks like sentences.
8-
9-
Returns an nv(g) x dims matrix of embeddings
6+
Returns an embedding matrix with size of `nv(g)` x `dims`. It computes node embeddings
7+
on graph `g` accroding to node2vec [node2vec2016](@cite). It performs biased random walks on the graph,
8+
then computes word embeddings by treating those random walks as sentences.
109
1110
# Arguments
1211
1312
- `g::FeaturedGraph`: The graph to perform random walk on.
1413
- `walks_per_node::Int`: Number of walks starting on each node,
1514
total number of walks is `nv(g) * walks_per_node`
1615
- `len::Int`: Length of random walks
17-
- `p::Real`: Return parameter from [1]
18-
- `q::Real`: In-out parameter from [1]
16+
- `p::Real`: Return parameter from [node2vec2016](@cite)
17+
- `q::Real`: In-out parameter from [node2vec2016](@cite)
1918
- `dims::Int`: Number of vector dimensions
20-
21-
22-
[1] https://cs.stanford.edu/~jure/pubs/node2vec-kdd16.pdf
2319
"""
2420
function node2vec(g::FeaturedGraph; walks_per_node::Int=100, len::Int=5, p::Real=0.5, q::Real=0.5, dims::Int=128)
2521
walks = simulate_walks(g; walks_per_node=walks_per_node, len=len, p=p, q=q)
@@ -44,8 +40,7 @@ embeddings using node ID as words.
4440
[2] https://github.com/JuliaText/Word2Vec.jl
4541
[3] https://code.google.com/archive/p/word2vec/
4642
"""
47-
function walks2vec(walks::Vector{Vector{Int}};dims::Int=100)
48-
43+
function walks2vec(walks::Vector{Vector{Int}}; dims::Int=100)
4944
str_walks=map(x -> string.(x),walks)
5045

5146
if Sys.iswindows()
@@ -56,7 +51,6 @@ function walks2vec(walks::Vector{Vector{Int}};dims::Int=100)
5651
the_walks = joinpath(rpath,"str_walk.txt")
5752
the_vecs = joinpath(rpath,"str_walk-vec.txt")
5853

59-
symbols = Iterators.flatten(walks) |> Set
6054
writedlm(the_walks,str_walks)
6155
word2vec(the_walks,the_vecs,verbose=true,size=dims)
6256
model=wordvectors(the_vecs)
@@ -69,7 +63,7 @@ end
6963
"""
7064
Conducts a random walk over `g` in O(l) time,
7165
weighted by alias sampling probabilities `alias_nodes`
72-
and `alias_edges`
66+
and `alias_edges`.
7367
"""
7468
function node2vec_walk(
7569
g::FeaturedGraph,
@@ -78,10 +72,9 @@ function node2vec_walk(
7872
start_node::Int,
7973
walk_length::Int)::Vector{Int}
8074
walk::Vector{Int} = [start_node]
81-
current::Int = start_node
8275
for _ in 2:walk_length
8376
curr = walk[end]
84-
cur_nbrs = sort(outneighbors(g, curr))
77+
cur_nbrs = sort(neighbors(g, curr; dir=:out))
8578
if length(walk) == 1
8679
push!(walk, cur_nbrs[alias_sample(alias_nodes[curr]...)])
8780
else
@@ -93,9 +86,11 @@ function node2vec_walk(
9386
return walk
9487
end
9588

96-
"Returns J and q for a given edge"
89+
"""
90+
Returns J and q for a given edge
91+
"""
9792
function get_alias_edge(g::FeaturedGraph, src::Int, dst::Int, p::Float64, q::Float64)::Alias
98-
unnormalized_probs = spzeros(length(outneighbors(g, dst)))
93+
unnormalized_probs = spzeros(length(neighbors(g, dst; dir=:out)))
9994
neighbor_weight_pairs = zip(weighted_outneighbors(g, dst)...)
10095
for (i, (dst_nbr, weight)) in enumerate(neighbor_weight_pairs)
10196
if dst_nbr == src
@@ -110,36 +105,45 @@ function get_alias_edge(g::FeaturedGraph, src::Int, dst::Int, p::Float64, q::Flo
110105
return alias_setup(normalized_probs)
111106
end
112107

108+
# Returns (neighbors::Vector{Int}, weights::Vector{Float64})
109+
function weighted_outneighbors(fg::FeaturedGraph, i::Int)
110+
nbrs = neighbors(fg, i; dir=:out)
111+
nbrs, sparse(graph(fg))[i, nbrs]
112+
end
113+
113114
"""
114115
Computes weighted probability transition aliases J and q for nodes and edges
115116
using return parameter `p` and In-out parameter `q`
116117
117-
Implementation as specified in the node2vec paper by Grover and Leskovec (2016)
118-
https://cs.stanford.edu/~jure/pubs/node2vec-kdd16.pdf
118+
Implementation as specified in the node2vec paper [node2vec2016](@cite).
119119
"""
120-
function preprocess_modified_weights(g::FeaturedGraph, p::Float64, q::Float64)
120+
function preprocess_modified_weights(g::FeaturedGraph, p::Real, q::Real)
121121

122-
alias_nodes::Dict{Int, Alias} = Dict()
123-
alias_edges::Dict{Tuple{Int, Int}, Alias} = Dict()
122+
alias_nodes = Dict{Int, Alias}()
123+
alias_edges = Dict{Tuple{Int, Int}, Alias}()
124124

125125
for node in 1:nv(g)
126-
probs = [1 / length(outneighbors(g, node)) for _ in outneighbors(g, node)]
126+
nbrs = neighbors(g, node, dir=:out)
127+
probs = fill(1, length(nbrs)) ./ length(nbrs)
127128
alias_nodes[node] = alias_setup(probs)
128129
end
129-
for edge in edges(g)
130-
alias_edges[(edge.src, edge.dst)] = get_alias_edge(g, edge.src, edge.dst, p, q)
130+
for (_, edge) in edges(g)
131+
src, dst = edge
132+
alias_edges[(src, dst)] = get_alias_edge(g, src, dst, p, q)
131133
if !is_directed(g)
132-
alias_edges[(edge.dst, edge.src)] = get_alias_edge(g, edge.dst, edge.src, p, q)
134+
alias_edges[(dst, src)] = get_alias_edge(g, dst, src, p, q)
133135
end
134136
end
135137
return alias_nodes, alias_edges
136138
end
137139

138140

139-
"Given a graph, compute `walks_per_node` * nv(g) random walks."
140-
function simulate_walks(g::FeaturedGraph; walks_per_node::Int, len::Int, p::Float64, q::Float64)::Vector{Vector{Int}}
141+
"""
142+
Given a graph, compute `walks_per_node` * nv(g) random walks.
143+
"""
144+
function simulate_walks(g::FeaturedGraph; walks_per_node::Int, len::Int, p::Real, q::Real)::Vector{Vector{Int}}
141145
alias_nodes, alias_edges = preprocess_modified_weights(g, p, q)
142-
walks::Vector{Vector{Int}} = []
146+
walks = Vector{Int}[]
143147
for _ in 1:walks_per_node
144148
for node in shuffle(1:nv(g))
145149
walk::Vector{Int} = node2vec_walk(g, alias_nodes, alias_edges; start_node=node, walk_length=len)

src/graph_embedding/sampling.jl renamed to src/sampling.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
11
"""
2-
Alias Sampling first described in [1]. [2] might be a helpful resource to understand alias sampling.
3-
4-
[1] A. Kronmal and A. V. Peterson. On the alias method for generating random variables from a discrete distribution. The American Statistician, 33(4):214-218, 1979.
5-
[2] https://lips.cs.princeton.edu/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
6-
"""
2+
alias_setup(probs)
73
8-
alias_setup(probs::Vector{Float64}) = alias_setup(sparse(probs))
9-
10-
"""
114
Computes alias probabilities.
125
"""
13-
function alias_setup(probs::SparseVector{Float64})::Tuple{SparseVector{Int}, SparseVector{Float64}}
6+
alias_setup(probs::AbstractVector{<:Real}) = alias_setup(sparse(probs))
7+
8+
function alias_setup(probs::SparseVector{<:Real})
149
K = length(probs)
1510
J = spzeros(Int, K)
1611
q = probs * K
1712

18-
smaller::Vector{Int} = [] # prob idxs < 1/K
19-
larger::Vector{Int} = [] # prob idxs >= 1/k
13+
smaller = Int[] # prob idxs < 1/K
14+
larger = Int[] # prob idxs >= 1/k
2015

2116
for i in 1:length(probs)
2217
if q[i] < 1.0 # equivalent to prob < 1/K but saves the division
@@ -25,6 +20,7 @@ function alias_setup(probs::SparseVector{Float64})::Tuple{SparseVector{Int}, Spa
2520
push!(larger, i)
2621
end
2722
end
23+
2824
while length(smaller) > 0 && length(larger) > 0
2925
small = pop!(smaller)
3026
large = pop!(larger)
@@ -40,9 +36,17 @@ function alias_setup(probs::SparseVector{Float64})::Tuple{SparseVector{Int}, Spa
4036
return J, q
4137
end
4238

43-
function alias_sample(J::SparseVector{Int}, q::SparseVector{Float64})::Int
39+
"""
40+
alias_sample(J, q)
41+
42+
Alias Sampling first described in [1]. [2] might be a helpful resource to understand alias sampling.
4443
45-
small_index = rand() * length(J) |> ceil |> Int
44+
[1] A. Kronmal and A. V. Peterson. On the alias method for generating random variables from a
45+
discrete distribution. The American Statistician, 33(4):214-218, 1979.
46+
[2] https://lips.cs.princeton.edu/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
47+
"""
48+
function alias_sample(J::AbstractVector{<:Integer}, q::AbstractVector{<:Real})
49+
small_index = ceil(Int, rand() * length(J))
4650
if rand() < q[small_index]
4751
return small_index
4852
else

src/utils.jl

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -20,84 +20,3 @@ end
2020

2121
@non_differentiable accumulated_edges(x...)
2222
@non_differentiable generate_cluster(x...)
23-
24-
"""
25-
edge_index_table(adj[, directed])
26-
27-
Generate a mapping from vertex pair (i, j) to edge index. The edge indecies are determined by
28-
the sorted vertex indecies.
29-
"""
30-
function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}, directed::Bool=is_directed(adj))
31-
table = Dict{Tuple{UInt32,UInt32},UInt64}()
32-
e = one(UInt64)
33-
if directed
34-
for (i, js) = enumerate(adj)
35-
js = sort(js)
36-
for j = js
37-
table[(i, j)] = e
38-
e += one(UInt64)
39-
end
40-
end
41-
else
42-
for (i, js) = enumerate(adj)
43-
js = sort(js)
44-
js = js[i .≤ js]
45-
for j = js
46-
table[(i, j)] = e
47-
table[(j, i)] = e
48-
e += one(UInt64)
49-
end
50-
end
51-
end
52-
table
53-
end
54-
55-
function edge_index_table(vpair::AbstractVector{<:Tuple})
56-
table = Dict{Tuple{UInt32,UInt32},UInt64}()
57-
for (i, p) = enumerate(vpair)
58-
table[p] = i
59-
end
60-
table
61-
end
62-
63-
edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph, fg.directed)
64-
65-
Zygote.@nograd edge_index_table
66-
67-
### TODO move these to GraphSignals ######
68-
import GraphSignals: FeaturedGraph
69-
70-
# function FeaturedGraph(fg::FeaturedGraph;
71-
# nf=node_feature(fg),
72-
# ef=edge_feature(fg),
73-
# gf=global_feature(fg))
74-
75-
# return FeaturedGraph(graph(fg); nf, ef, gf)
76-
# end
77-
78-
79-
function edges(fg::FeaturedGraph)
80-
edges = []
81-
for (src, vec) in enumerate(adjacency_list(GraphSignals.adjacency_matrix(fg)))
82-
for v in vec
83-
push!(edges, Edge(src, v))
84-
end
85-
end
86-
edges
87-
end
88-
89-
Graphs.has_edge(fg::FeaturedGraph, u::Int, v::Int) = has_edge(graph(fg), u, v)
90-
91-
# Returns (neighbors::Vector{Int}, weights::Vector{Float64})
92-
function weighted_outneighbors(fg::FeaturedGraph, v::Int)
93-
nbrs = neighbors(fg,v; dir=:out)
94-
nbrs, graph(fg).S[v, nbrs]
95-
end
96-
97-
function check_num_nodes(fg::FeaturedGraph, x::AbstractArray)
98-
@assert nv(fg) == size(x, ndims(x))
99-
end
100-
101-
function check_num_edges(fg::FeaturedGraph, e::AbstractArray)
102-
@assert ne(fg) == size(e, ndims(e))
103-
end
File renamed without changes.

0 commit comments

Comments
 (0)