
Commit 4494757

clustering pytorch
1 parent d33896c commit 4494757

2 files changed: +194 -77 lines changed

Clustering_pytorch.py

Lines changed: 111 additions & 0 deletions
```python
import os.path as osp

import torch
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphConv, dense_mincut_pool
from torch_geometric import utils
from torch_geometric.nn import Sequential
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from sklearn.metrics import normalized_mutual_info_score as NMI

torch.manual_seed(0)  # for reproducibility

# Load data
dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '.', 'data', dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
data = dataset[0]

# Normalized adjacency matrix
data.edge_index, data.edge_weight = gcn_norm(
    data.edge_index, data.edge_weight, data.num_nodes,
    add_self_loops=False, dtype=data.x.dtype)


class Net(torch.nn.Module):
    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=[],
                 mlp_act="Identity"):
        super().__init__()

        mp_act = getattr(torch.nn, mp_act)(inplace=True)
        mlp_act = getattr(torch.nn, mlp_act)(inplace=True)

        # Message passing layers
        mp = [
            (GraphConv(in_channels, mp_units[0]), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units) - 1):
            mp.append((GraphConv(mp_units[i], mp_units[i + 1]), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        # MLP layers
        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

    def forward(self, x, edge_index, edge_weight):

        # Propagate node feats
        x = self.mp(x, edge_index, edge_weight)

        # Cluster assignments (logits)
        s = self.mlp(x)

        # Obtain MinCutPool losses
        adj = utils.to_dense_adj(edge_index, edge_attr=edge_weight)
        _, _, mc_loss, o_loss = dense_mincut_pool(x, adj, s)

        return torch.softmax(s, dim=-1), mc_loss, o_loss


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
model = Net([16], "ELU", dataset.num_features, dataset.num_classes).to(device)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


def train():
    model.train()
    optimizer.zero_grad()
    _, mc_loss, o_loss = model(data.x, data.edge_index, data.edge_weight)
    loss = mc_loss + o_loss
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def test():
    model.eval()
    clust, _, _ = model(data.x, data.edge_index, data.edge_weight)
    return NMI(clust.max(1)[1].cpu(), data.y.cpu())


patience = 50
best_nmi = 0
for epoch in range(1, 10000):
    train_loss = train()
    nmi = test()
    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, NMI: {nmi:.3f}')
    if nmi > best_nmi:
        best_nmi = nmi
        patience = 50
    else:
        patience -= 1
    if patience == 0:
        break
```

README.md

Lines changed: 83 additions & 77 deletions
# Spectral Clustering with Graph Neural Networks for Graph Pooling

<img src="./figs/mincutpool.png" width="400" height="200">

This code reproduces the experimental results obtained with the MinCutPool layer, as presented in the ICML 2020 paper

[Spectral Clustering with Graph Neural Networks for Graph Pooling](https://arxiv.org/abs/1907.00481)
F. M. Bianchi*, D. Grattarola*, C. Alippi

The official TensorFlow implementation of the MinCutPool layer is in
[Spektral](https://graphneural.network/layers/pooling/#mincutpool).

The PyTorch implementation of MinCutPool is in
[PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool).
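For quick reference, `dense_mincut_pool` takes dense node features, a dense adjacency matrix, and a soft cluster-assignment matrix, and returns the pooled features and adjacency together with the two auxiliary losses. A minimal sketch with random tensors (the shapes below are chosen purely for illustration):

```python
import torch
from torch_geometric.nn import dense_mincut_pool

x = torch.randn(1, 20, 16)    # (batch, nodes, features) node embeddings
adj = torch.rand(1, 20, 20)   # (batch, nodes, nodes) dense adjacency
s = torch.randn(1, 20, 4)     # (batch, nodes, clusters) assignment logits

x_pool, adj_pool, mincut_loss, ortho_loss = dense_mincut_pool(x, adj, s)
print(x_pool.shape, adj_pool.shape)   # -> (1, 4, 16) and (1, 4, 4)
```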
## Setup

The code is based on Python 3.5, TensorFlow 1.15, and Spektral 0.1.2.
All required libraries are listed in `requirements.txt` and can be installed with

```bash
pip install -r requirements.txt
```

## Image segmentation

<img src="./figs/overseg_and_rag.png" width="700" height="150">

Run [Segmentation.py](https://github.com/FilippoMB/Spectral-Clustering-with-Graph-Neural-Networks-for-Graph-Pooling/blob/master/Segmentation.py)
to perform hyper-segmentation, generate a Region Adjacency Graph from the
resulting segments, and then cluster the nodes of the RAG with the
MinCutPool layer.
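Purely as an illustration of the oversegmentation-plus-RAG idea (this is not the code of `Segmentation.py`, and it assumes a recent scikit-image where the RAG utilities live in `skimage.graph`):

```python
# Illustrative sketch only: oversegment an image with SLIC and build its RAG.
import networkx as nx
from skimage import data, segmentation, graph

image = data.astronaut()                                            # any RGB image
labels = segmentation.slic(image, n_segments=500, compactness=10)   # hyper-segmentation
rag = graph.rag_mean_color(image, labels)                           # one RAG node per segment

adj = nx.to_numpy_array(rag, weight='weight')   # dense adjacency of the RAG
print(adj.shape)                                # roughly (500, 500)
```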
## Clustering

<img src="./figs/clustering_stats.png" width="600" height="250">

Run [Clustering.py](https://github.com/FilippoMB/Spectral-Clustering-with-Graph-Neural-Networks-for-Graph-Pooling/blob/master/Clustering.py)
to cluster the nodes of a citation network. The datasets `cora`, `citeseer`, and
`pubmed` can be selected.
Results are provided in terms of homogeneity score, completeness score, and
normalized mutual information (v-score).
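These are the standard scikit-learn clustering metrics; a toy sketch of how they can be computed (the label arrays below are made up and not output of `Clustering.py`):

```python
from sklearn.metrics import (completeness_score, homogeneity_score,
                             v_measure_score)

y_true = [0, 0, 1, 1, 2, 2]   # ground-truth classes (toy example)
y_pred = [1, 1, 0, 0, 2, 2]   # predicted cluster assignments (toy example)

print(f'Homogeneity:  {homogeneity_score(y_true, y_pred):.3f}')
print(f'Completeness: {completeness_score(y_true, y_pred):.3f}')
print(f'V-score:      {v_measure_score(y_true, y_pred):.3f}')
```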
### PyTorch

[Clustering_pytorch.py](https://github.com/FilippoMB/Spectral-Clustering-with-Graph-Neural-Networks-for-Graph-Pooling/blob/master/Clustering_pytorch.py) contains a basic implementation in PyTorch based on [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool).

## Autoencoder

<img src="./figs/ae_ring.png" width="400" height="200">
<img src="./figs/ae_grid.png" width="400" height="200">

Run [Autoencoder.py](https://github.com/FilippoMB/Spectral-Clustering-with-Graph-Neural-Networks-for-Graph-Pooling/blob/master/Autoencoder.py)
to train an autoencoder with a bottleneck and compute the reconstructed graph. It
is possible to switch between the `ring` and `grid` graphs; any other
[point cloud](https://pygsp.readthedocs.io/en/stable/reference/graphs.html?highlight=bunny#graph-models)
from the [PyGSP](https://pygsp.readthedocs.io/en/stable/index.html) library
is also supported. Results are provided in terms of the Mean Squared Error.
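As a quick sketch of how such graphs can be obtained from PyGSP (assuming only that the library is installed; the sizes below are arbitrary):

```python
# Minimal sketch: build a PyGSP graph and extract its dense adjacency matrix.
from pygsp import graphs

G = graphs.Ring(N=64)       # or graphs.Grid2d(N1=8, N2=8), graphs.Bunny(), ...
A = G.W.toarray()           # dense (N, N) weighted adjacency matrix
print(G.N, A.shape)
```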
## Graph Classification

Run [Graph_Classification.py](https://github.com/FilippoMB/Spectral-Clustering-with-Graph-Neural-Networks-for-Graph-Pooling/blob/master/Graph_Classification.py) to train a graph classifier. Additional classification datasets are available [here](https://chrsmrrs.github.io/datasets/) (drop them in `data/classification/`) and [here](https://github.com/FilippoMB/Benchmark_dataset_for_graph_classification) (drop them in `data/`).
Results are provided in terms of classification accuracy averaged over 10 runs.

### PyTorch

A basic PyTorch implementation of the graph classification task can be found in this [example](https://github.com/pyg-team/pytorch_geometric/blob/a238110ff5ac772656c967f135fa138add6dabb4/examples/proteins_mincut_pool.py) from PyTorch Geometric.
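That example uses the PROTEINS benchmark; a hedged sketch of how such a TU dataset can be loaded in PyTorch Geometric (the `root` path is just a placeholder):

```python
# Sketch only: load a TU graph-classification benchmark with PyTorch Geometric.
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='data/TUDataset', name='PROTEINS')   # placeholder root path
loader = DataLoader(dataset, batch_size=32, shuffle=True)
print(len(dataset), dataset.num_classes, dataset.num_features)
```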
## Citation

Please cite the original paper if you use MinCutPool in your research:

    @inproceedings{bianchi2020mincutpool,
        title={Spectral Clustering with Graph Neural Networks for Graph Pooling},
        author={Bianchi, Filippo Maria and Grattarola, Daniele and Alippi, Cesare},
        booktitle={Proceedings of the 37th International Conference on Machine Learning},
        pages={2729-2738},
        year={2020},
        organization={ACM}
    }

## License

The code is released under the MIT License. See the attached LICENSE file.
