Skip to content

Commit 16d43e5

Browse files
authored
Merge pull request #247 from jarbus/node2vec
Node2vec prototype
2 parents 2d5de77 + 3b24f6e commit 16d43e5

File tree

12 files changed

+329
-7
lines changed

12 files changed

+329
-7
lines changed

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.8.0"
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
10+
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1011
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1112
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1213
GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8"
@@ -17,7 +18,9 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1718
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
1819
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1920
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
21+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2022
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
23+
Word2Vec = "c64b6f0f-98cd-51d1-af78-58ae84944834"
2124
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2225

2326
[compat]
@@ -27,17 +30,18 @@ DataStructures = "0.18"
2730
FillArrays = "0.12"
2831
Flux = "0.12"
2932
GraphMLDatasets = "0.1"
30-
GraphSignals = "0.3"
3133
Graphs = "1.4"
3234
NNlib = "0.7"
3335
NNlibCUDA = "0.1"
3436
Reexport = "1.1"
37+
Word2Vec = "0.5"
3538
Zygote = "0.6"
3639
julia = "1.6"
3740

3841
[extras]
42+
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
3943
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
4044
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4145

4246
[targets]
43-
test = ["SparseArrays", "Test"]
47+
test = ["Clustering", "SparseArrays", "Test"]

docs/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
3+
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
34

45
[compat]
5-
Documenter = "0.24"
6+
Documenter = "0.27"

docs/bibliography.bib

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@inproceedings{node2vec2016,
2+
author = {Grover, Aditya and Leskovec, Jure},
3+
title = {Node2vec: Scalable Feature Learning for Networks},
4+
year = {2016},
5+
isbn = {9781450342322},
6+
publisher = {Association for Computing Machinery},
7+
address = {New York, NY, USA},
8+
url = {https://doi.org/10.1145/2939672.2939754},
9+
doi = {10.1145/2939672.2939754},
10+
abstract = {Prediction tasks over nodes and edges in networks require careful effort in engineering features used by learning algorithms. Recent research in the broader field of representation learning has led to significant progress in automating prediction by learning the features themselves. However, present feature learning approaches are not expressive enough to capture the diversity of connectivity patterns observed in networks. Here we propose node2vec, an algorithmic framework for learning continuous feature representations for nodes in networks. In node2vec, we learn a mapping of nodes to a low-dimensional space of features that maximizes the likelihood of preserving network neighborhoods of nodes. We define a flexible notion of a node's network neighborhood and design a biased random walk procedure, which efficiently explores diverse neighborhoods. Our algorithm generalizes prior work which is based on rigid notions of network neighborhoods, and we argue that the added flexibility in exploring neighborhoods is the key to learning richer representations.We demonstrate the efficacy of node2vec over existing state-of-the-art techniques on multi-label classification and link prediction in several real-world networks from diverse domains. Taken together, our work represents a new way for efficiently learning state-of-the-art task-independent representations in complex networks.},
11+
booktitle = {Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining},
12+
pages = {855–864},
13+
numpages = {10},
14+
keywords = {node embeddings, information networks, graph representations, feature learning},
15+
location = {San Francisco, California, USA},
16+
series = {KDD '16}
17+
}

docs/make.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
using Documenter
2+
using DocumenterCitations
23
using GeometricFlux
34

5+
bib = CitationBibliography(joinpath(@__DIR__, "bibliography.bib"), sorting=:nyt)
6+
47
makedocs(
8+
bib,
59
sitename = "GeometricFlux.jl",
610
format = Documenter.HTML(
711
assets = ["assets/flux.css"],
@@ -24,7 +28,8 @@ makedocs(
2428
["Convolutional Layers" => "manual/conv.md",
2529
"Pooling Layers" => "manual/pool.md",
2630
"Models" => "manual/models.md",
27-
"Linear Algebra" => "manual/linalg.md"]
31+
"Linear Algebra" => "manual/linalg.md"],
32+
"References" => "references.md",
2833
]
2934
)
3035

docs/src/references.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# References
2+
3+
```@bibliography
4+
```

examples/node2vec.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using GeometricFlux
2+
using GraphSignals
3+
using Graphs
4+
using SparseArrays
5+
using Plots
6+
using GraphPlot
7+
using Clustering
8+
using Cairo, Compose
9+
10+
clusters = [1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
11+
12+
int2col_str(x::Int) = x==1 ? "lightblue" : "red"
13+
14+
15+
g = smallgraph(:karate)
16+
fg = FeaturedGraph(g)
17+
vectors = node2vec(fg; walks_per_node=10, len=80, p=1.0, q=1.0)
18+
R = kmeans(vectors, 2)
19+
20+
21+
learned_clusters = copy(assignments(R))
22+
# ensure that the cluster containing node 1 is cluster 1
23+
if assignments(R)[1] != 1
24+
learned_clusters = [i == 1 ? 2 : 1 for i in assignments(R)]
25+
end
26+
27+
output_plot_name = "karateclub.pdf"
28+
draw(
29+
PDF(output_plot_name, 16cm, 16cm),
30+
gplot(g,
31+
nodelabel=map(string, 1:34),
32+
nodefillc=[int2col_str(learned_clusters[i]) for i in 1:34],
33+
nodestrokec=["white" for _ in 1:34]
34+
)
35+
)
36+
37+
incorrect = sum(learned_clusters .!= clusters)
38+
println(incorrect, " incorrect cluster labelings")
39+
println("Drawn graph to ", output_plot_name)

src/GeometricFlux.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
module GeometricFlux
22

3+
using DelimitedFiles
4+
using SparseArrays
35
using Statistics: mean
46
using LinearAlgebra: Adjoint, norm, Transpose
7+
using Random
58
using Reexport
69

710
using CUDA
811
using ChainRulesCore: @non_differentiable
912
using FillArrays: Fill
1013
using Flux
1114
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
12-
using NNlib, NNlibCUDA
13-
using GraphSignals
15+
@reexport using GraphSignals
1416
using Graphs
17+
using NNlib, NNlibCUDA
1518
using Zygote
1619

20+
import Word2Vec: word2vec, wordvectors, get_vector
21+
1722
export
1823
# layers/graphlayers
1924
AbstractGraphLayer,
@@ -52,7 +57,10 @@ export
5257
bypass_graph,
5358

5459
# utils
55-
generate_cluster
60+
generate_cluster,
61+
62+
#node2vec
63+
node2vec
5664

5765
include("datasets.jl")
5866

@@ -67,6 +75,9 @@ include("layers/pool.jl")
6775
include("models.jl")
6876
include("layers/misc.jl")
6977

78+
include("sampling.jl")
79+
include("embedding/node2vec.jl")
80+
7081
include("cuda/conv.jl")
7182

7283
using .Datasets

src/embedding/node2vec.jl

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
const Alias = Tuple{SparseVector{Int}, SparseVector{Float64}}
2+
3+
"""
4+
node2vec(g; walks_per_node, len, p, q, dims)
5+
6+
Returns an embedding matrix with size of `nv(g)` x `dims`. It computes node embeddings
7+
on graph `g` accroding to node2vec [node2vec2016](@cite). It performs biased random walks on the graph,
8+
then computes word embeddings by treating those random walks as sentences.
9+
10+
# Arguments
11+
12+
- `g::FeaturedGraph`: The graph to perform random walk on.
13+
- `walks_per_node::Int`: Number of walks starting on each node,
14+
total number of walks is `nv(g) * walks_per_node`
15+
- `len::Int`: Length of random walks
16+
- `p::Real`: Return parameter from [node2vec2016](@cite)
17+
- `q::Real`: In-out parameter from [node2vec2016](@cite)
18+
- `dims::Int`: Number of vector dimensions
19+
"""
20+
function node2vec(g::FeaturedGraph; walks_per_node::Int=100, len::Int=5, p::Real=0.5, q::Real=0.5, dims::Int=128)
21+
walks = simulate_walks(g; walks_per_node=walks_per_node, len=len, p=p, q=q)
22+
model = walks2vec(walks,dims=dims)
23+
vecs = []
24+
println(typeof(model))
25+
for i in 1:nv(g)
26+
push!(vecs, get_vector(model, string(i)))
27+
end
28+
matrix = cat(vecs..., dims=2)
29+
return matrix
30+
end
31+
32+
"""
33+
Modified version of Node2Vec.learn_embeddings[1]. Uses
34+
a Julia interface[2] to the original word2vec C code[3].
35+
36+
Treats each random walk like a sentence, and computed word
37+
embeddings using node ID as words.
38+
39+
[1] https://github.com/ollin18/Node2Vec.jl
40+
[2] https://github.com/JuliaText/Word2Vec.jl
41+
[3] https://code.google.com/archive/p/word2vec/
42+
"""
43+
function walks2vec(walks::Vector{Vector{Int}}; dims::Int=100)
44+
str_walks=map(x -> string.(x),walks)
45+
46+
if Sys.iswindows()
47+
rpath = pwd()
48+
else
49+
rpath = "/tmp"
50+
end
51+
the_walks = joinpath(rpath,"str_walk.txt")
52+
the_vecs = joinpath(rpath,"str_walk-vec.txt")
53+
54+
writedlm(the_walks,str_walks)
55+
word2vec(the_walks,the_vecs,verbose=true,size=dims)
56+
model=wordvectors(the_vecs)
57+
rm(the_walks)
58+
rm(the_vecs)
59+
model
60+
end
61+
62+
63+
"""
64+
Conducts a random walk over `g` in O(l) time,
65+
weighted by alias sampling probabilities `alias_nodes`
66+
and `alias_edges`.
67+
"""
68+
function node2vec_walk(
69+
g::FeaturedGraph,
70+
alias_nodes::Dict{Int, Alias},
71+
alias_edges::Dict{Tuple{Int, Int}, Alias};
72+
start_node::Int,
73+
walk_length::Int)::Vector{Int}
74+
walk::Vector{Int} = [start_node]
75+
for _ in 2:walk_length
76+
curr = walk[end]
77+
cur_nbrs = sort(neighbors(g, curr; dir=:out))
78+
if length(walk) == 1
79+
push!(walk, cur_nbrs[alias_sample(alias_nodes[curr]...)])
80+
else
81+
prev = walk[end-1]
82+
next = cur_nbrs[alias_sample(alias_edges[(prev, curr)]...)]
83+
push!(walk, next)
84+
end
85+
end
86+
return walk
87+
end
88+
89+
"""
90+
Returns J and q for a given edge
91+
"""
92+
function get_alias_edge(g::FeaturedGraph, src::Int, dst::Int, p::Float64, q::Float64)::Alias
93+
unnormalized_probs = spzeros(length(neighbors(g, dst; dir=:out)))
94+
neighbor_weight_pairs = zip(weighted_outneighbors(g, dst)...)
95+
for (i, (dst_nbr, weight)) in enumerate(neighbor_weight_pairs)
96+
if dst_nbr == src
97+
unnormalized_probs[i] = weight/p
98+
elseif has_edge(g, dst_nbr, src)
99+
unnormalized_probs[i] = weight
100+
else
101+
unnormalized_probs[i] = weight/q
102+
end
103+
end
104+
normalized_probs = unnormalized_probs ./ sum(unnormalized_probs)
105+
return alias_setup(normalized_probs)
106+
end
107+
108+
# Returns (neighbors::Vector{Int}, weights::Vector{Float64})
109+
function weighted_outneighbors(fg::FeaturedGraph, i::Int)
110+
nbrs = neighbors(fg, i; dir=:out)
111+
nbrs, sparse(graph(fg))[i, nbrs]
112+
end
113+
114+
"""
115+
Computes weighted probability transition aliases J and q for nodes and edges
116+
using return parameter `p` and In-out parameter `q`
117+
118+
Implementation as specified in the node2vec paper [node2vec2016](@cite).
119+
"""
120+
function preprocess_modified_weights(g::FeaturedGraph, p::Real, q::Real)
121+
122+
alias_nodes = Dict{Int, Alias}()
123+
alias_edges = Dict{Tuple{Int, Int}, Alias}()
124+
125+
for node in 1:nv(g)
126+
nbrs = neighbors(g, node, dir=:out)
127+
probs = fill(1, length(nbrs)) ./ length(nbrs)
128+
alias_nodes[node] = alias_setup(probs)
129+
end
130+
for (_, edge) in edges(g)
131+
src, dst = edge
132+
alias_edges[(src, dst)] = get_alias_edge(g, src, dst, p, q)
133+
if !is_directed(g)
134+
alias_edges[(dst, src)] = get_alias_edge(g, dst, src, p, q)
135+
end
136+
end
137+
return alias_nodes, alias_edges
138+
end
139+
140+
141+
"""
142+
Given a graph, compute `walks_per_node` * nv(g) random walks.
143+
"""
144+
function simulate_walks(g::FeaturedGraph; walks_per_node::Int, len::Int, p::Real, q::Real)::Vector{Vector{Int}}
145+
alias_nodes, alias_edges = preprocess_modified_weights(g, p, q)
146+
walks = Vector{Int}[]
147+
for _ in 1:walks_per_node
148+
for node in shuffle(1:nv(g))
149+
walk::Vector{Int} = node2vec_walk(g, alias_nodes, alias_edges; start_node=node, walk_length=len)
150+
push!(walks, walk)
151+
end
152+
end
153+
return walks
154+
end

src/sampling.jl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
alias_setup(probs)
3+
4+
Computes alias probabilities.
5+
"""
6+
alias_setup(probs::AbstractVector{<:Real}) = alias_setup(sparse(probs))
7+
8+
function alias_setup(probs::SparseVector{<:Real})
9+
K = length(probs)
10+
J = spzeros(Int, K)
11+
q = probs * K
12+
13+
smaller = Int[] # prob idxs < 1/K
14+
larger = Int[] # prob idxs >= 1/k
15+
16+
for i in 1:length(probs)
17+
if q[i] < 1.0 # equivalent to prob < 1/K but saves the division
18+
push!(smaller, i)
19+
else
20+
push!(larger, i)
21+
end
22+
end
23+
24+
while length(smaller) > 0 && length(larger) > 0
25+
small = pop!(smaller)
26+
large = pop!(larger)
27+
J[small] = large
28+
q[large] = q[large] + q[small] - 1.0
29+
if q[large] < 1.0
30+
push!(smaller, large)
31+
else
32+
push!(larger, large)
33+
end
34+
end
35+
36+
return J, q
37+
end
38+
39+
"""
40+
alias_sample(J, q)
41+
42+
Alias Sampling first described in [1]. [2] might be a helpful resource to understand alias sampling.
43+
44+
[1] A. Kronmal and A. V. Peterson. On the alias method for generating random variables from a
45+
discrete distribution. The American Statistician, 33(4):214-218, 1979.
46+
[2] https://lips.cs.princeton.edu/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
47+
"""
48+
function alias_sample(J::AbstractVector{<:Integer}, q::AbstractVector{<:Real})
49+
small_index = ceil(Int, rand() * length(J))
50+
if rand() < q[small_index]
51+
return small_index
52+
else
53+
return J[small_index]
54+
end
55+
end

0 commit comments

Comments
 (0)