Skip to content

Commit 98c855a

Browse files
committed
add data.Dictionary & data.Graph.match
1 parent 6f92bc3 commit 98c855a

File tree

3 files changed

+433
-128
lines changed

3 files changed

+433
-128
lines changed

test/data/test_graph.py

Lines changed: 64 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ def block_diag(self, tensors):
3636
y += num_col
3737
return result
3838

39+
def assert_equal(self, graph1, graph2, prompt):
    """Assert that two graphs have equal adjacency and equal shared features.

    The dense adjacency matrices are always compared. Each of
    ``node_feature``, ``edge_feature`` and ``graph_feature`` is compared
    only when both graphs expose that attribute; otherwise it is skipped.

    Parameters:
        graph1: first graph to compare
        graph2: second graph to compare
        prompt: short description of the tested operation, interpolated
            into the failure messages
    """
    same_adjacency = torch.equal(graph1.adjacency.to_dense(), graph2.adjacency.to_dense())
    self.assertTrue(same_adjacency, "Incorrect edge list in %s" % prompt)

    feature_message = "Incorrect feature in %s" % prompt
    for name in ("node_feature", "edge_feature", "graph_feature"):
        # a feature present on only one side is not compared
        if not (hasattr(graph1, name) and hasattr(graph2, name)):
            continue
        self.assertTrue(torch.equal(getattr(graph1, name), getattr(graph2, name)), feature_message)
48+
3949
def test_type_cast(self):
4050
dense_edge_feature = torch.zeros(self.num_node, self.num_node, self.num_feature)
4151
dense_edge_feature[tuple(self.edge_list.t())] = self.edge_feature
@@ -44,12 +54,8 @@ def test_type_cast(self):
4454
node_feature=self.node_feature.tolist(), edge_feature=self.edge_feature.tolist())
4555
graph2 = data.Graph(self.edge_list.numpy(), self.edge_weight.numpy(), self.num_node,
4656
node_feature=self.node_feature.numpy(), edge_feature=self.edge_feature.numpy())
47-
self.assertTrue(torch.equal(graph.edge_list, graph1.edge_list), "Incorrect type cast")
48-
self.assertTrue(torch.equal(graph.edge_feature, graph1.edge_feature), "Incorrect type cast")
49-
self.assertTrue(torch.equal(graph1.edge_list, graph2.edge_list), "Incorrect type cast")
50-
self.assertTrue(torch.equal(graph1.edge_weight, graph2.edge_weight), "Incorrect type cast")
51-
self.assertTrue(torch.equal(graph1.node_feature, graph2.node_feature), "Incorrect type cast")
52-
self.assertTrue(torch.equal(graph1.edge_feature, graph2.edge_feature), "Incorrect type cast")
57+
self.assert_equal(graph, graph1, "type cast")
58+
self.assert_equal(graph, graph2, "type cast")
5359

5460
def test_index(self):
5561
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
@@ -58,7 +64,7 @@ def test_index(self):
5864
index = tuple(torch.randint(self.num_node, (2,)).tolist())
5965
result = graph[index]
6066
truth = self.adjacency[index]
61-
self.assertTrue(torch.equal(result, truth), "Incorrect index in single item")
67+
self.assertTrue(torch.equal(result, truth), "Incorrect edge in single item")
6268

6369
h_index = torch.randperm(self.num_node)[:self.num_node // 2]
6470
t_index = torch.randperm(self.num_node)[:self.num_node // 2]
@@ -71,7 +77,7 @@ def test_index(self):
7177
adj_truth[not_h_index, :] = 0
7278
adj_truth[:, not_t_index] = 0
7379
feat_truth = self.node_feature
74-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in node mask")
80+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in node mask")
7581
self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in node mask")
7682

7783
new_graph = graph[:, 1: -1]
@@ -80,7 +86,7 @@ def test_index(self):
8086
adj_truth = torch.zeros_like(self.adjacency)
8187
adj_truth[:, 1: -1] = self.adjacency[:, 1: -1]
8288
feat_truth = self.node_feature
83-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in slice")
89+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in slice")
8490
self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in slice")
8591

8692
index = torch.randperm(self.num_node)[:self.num_node // 2]
@@ -89,7 +95,7 @@ def test_index(self):
8995
feat_result = new_graph.node_feature
9096
adj_truth = self.adjacency[index][:, index]
9197
feat_truth = self.node_feature[index]
92-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in subgraph")
98+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in subgraph")
9399
self.assertTrue(torch.equal(feat_result, feat_truth), "Incorrect feature in subgraph")
94100

95101
def test_device(self):
@@ -100,11 +106,7 @@ def test_device(self):
100106
self.assertEqual(graph1.adjacency.device.type, "cuda", "Incorrect device")
101107
graph2 = graph1.cpu()
102108
self.assertEqual(graph2.adjacency.device.type, "cpu", "Incorrect device")
103-
self.assertTrue(torch.equal(graph.adjacency.to_dense(), graph2.adjacency.to_dense()),
104-
"Incorrect feature when changing device")
105-
self.assertTrue(torch.equal(graph.node_feature, graph2.node_feature), "Incorrect feature when changing device")
106-
self.assertTrue(torch.equal(graph.edge_feature, graph2.edge_feature), "Incorrect feature when changing device")
107-
self.assertTrue(torch.equal(graph.graph_feature, graph2.graph_feature), "Incorrect feature when changing device")
109+
self.assert_equal(graph, graph2, "device")
108110

109111
def test_pack(self):
110112
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
@@ -128,26 +130,15 @@ def test_pack(self):
128130
edge_feat_truth = torch.cat([graph.edge_feature for graph in graphs])
129131
graph_feat_result = packed_graph.graph_feature
130132
graph_feat_truth = torch.stack([graph.graph_feature for graph in graphs])
131-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in pack")
133+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in pack")
132134
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in pack")
133135
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in pack")
134136
self.assertTrue(torch.equal(graph_feat_result, graph_feat_truth), "Incorrect feature in pack")
135137

136138
new_graphs = packed_graph.unpack()
137139
self.assertEqual(len(graphs), len(new_graphs), "Incorrect length in unpack")
138140
for graph, new_graph in zip(graphs, new_graphs):
139-
adj_truth = graph.adjacency.to_dense()
140-
adj_result = new_graph.adjacency.to_dense()
141-
node_feat_truth = graph.node_feature
142-
node_feat_result = new_graph.node_feature
143-
edge_feat_truth = graph.edge_feature
144-
edge_feat_result = new_graph.edge_feature
145-
graph_feat_truth = graph.graph_feature
146-
graph_feat_result = new_graph.graph_feature
147-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in unpack")
148-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in unpack")
149-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in unpack")
150-
self.assertTrue(torch.equal(graph_feat_result, graph_feat_truth), "Incorrect feature in unpack")
141+
self.assert_equal(graph, new_graph, "unpack")
151142

152143
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
153144
node_feature=self.node_feature, edge_feature=self.edge_feature)
@@ -158,28 +149,12 @@ def test_pack(self):
158149
for start in range(4):
159150
mask[start * self.num_node + start: (start + 1) * self.num_node] = 1
160151
packed_graph2 = packed_graph2.subgraph(mask)
161-
adj_result = packed_graph2.adjacency.to_dense()
162-
adj_truth = packed_graph.adjacency.to_dense()
163-
node_feat_result = packed_graph2.node_feature
164-
node_feat_truth = packed_graph.node_feature
165-
edge_feat_result = packed_graph2.edge_feature
166-
edge_feat_truth = packed_graph.edge_feature
167-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in subgraph")
168-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in subgraph")
169-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in subgraph")
152+
self.assert_equal(packed_graph, packed_graph2, "subgraph")
170153

171154
packed_graph = data.Graph.pack(graphs[::2])
172155
packed_graph2 = data.Graph.pack(graphs)[::2]
173-
adj_result = packed_graph2.adjacency.to_dense()
174-
adj_truth = packed_graph.adjacency.to_dense()
175-
node_feat_result = packed_graph2.node_feature
176-
node_feat_truth = packed_graph.node_feature
177-
edge_feat_result = packed_graph2.edge_feature
178-
edge_feat_truth = packed_graph.edge_feature
179156
self.assertEqual(len(packed_graph), len(packed_graph2), "Incorrect batch size in graph mask")
180-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in graph mask")
181-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in graph mask")
182-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in graph mask")
157+
self.assert_equal(packed_graph, packed_graph2, "graph mask")
183158

184159
def test_reorder(self):
185160
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
@@ -192,9 +167,9 @@ def test_reorder(self):
192167
node_feat_truth = graph.node_feature[order]
193168
edge_feat_result = new_graph.edge_feature
194169
edge_feat_truth = graph.edge_feature
195-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect node reorder")
196-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect node reorder")
197-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect node reorder")
170+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect edge list in node reorder")
171+
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in node reorder")
172+
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in node reorder")
198173

199174
order = torch.randperm(graph.num_edge)
200175
new_graph = graph.edge_mask(order)
@@ -204,9 +179,9 @@ def test_reorder(self):
204179
node_feat_truth = graph.node_feature
205180
edge_feat_result = new_graph.edge_feature
206181
edge_feat_truth = graph.edge_feature[order]
207-
self.assertTrue(torch.equal(edge_result, edge_truth), "Incorrect edge reorder")
208-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect edge reorder")
209-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect edge reorder")
182+
self.assertTrue(torch.equal(edge_result, edge_truth), "Incorrect edge list in edge reorder")
183+
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in edge reorder")
184+
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in edge reorder")
210185

211186
graphs = []
212187
for start in range(4):
@@ -216,30 +191,14 @@ def test_reorder(self):
216191
order = torch.randperm(4)
217192
packed_graph = packed_graph.subbatch(order)
218193
packed_graph2 = data.Graph.pack([graphs[i] for i in order])
219-
adj_result = packed_graph.adjacency.to_dense()
220-
adj_truth = packed_graph2.adjacency.to_dense()
221-
node_feat_result = packed_graph.node_feature
222-
node_feat_truth = packed_graph2.node_feature
223-
edge_feat_result = packed_graph.edge_feature
224-
edge_feat_truth = packed_graph2.edge_feature
225-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect graph reorder")
226-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect graph reorder")
227-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect graph reorder")
194+
self.assert_equal(packed_graph, packed_graph2, "graph reorder")
228195

229196
def test_repeat(self):
230197
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
231198
node_feature=self.node_feature, edge_feature=self.edge_feature)
232199
repeat_graph = graph.repeat(5)
233200
true_graph = data.Graph.pack([graph] * 5)
234-
adj_result = repeat_graph.adjacency.to_dense()
235-
adj_truth = true_graph.adjacency.to_dense()
236-
node_feat_result = repeat_graph.node_feature
237-
node_feat_truth = true_graph.node_feature
238-
edge_feat_result = repeat_graph.edge_feature
239-
edge_feat_truth = true_graph.edge_feature
240-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in repeat")
241-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in repeat")
242-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in repeat")
201+
self.assert_equal(repeat_graph, true_graph, "repeat")
243202

244203
# special case: graphs with no edges
245204
graphs = [graph.edge_mask([]), graph.edge_mask([])]
@@ -249,15 +208,7 @@ def test_repeat(self):
249208
packed_graph = data.Graph.pack(graphs)
250209
repeat_graph = packed_graph.repeat(5)
251210
true_graph = data.Graph.pack(graphs * 5)
252-
adj_result = repeat_graph.adjacency.to_dense()
253-
adj_truth = true_graph.adjacency.to_dense()
254-
node_feat_result = repeat_graph.node_feature
255-
node_feat_truth = true_graph.node_feature
256-
edge_feat_result = repeat_graph.edge_feature
257-
edge_feat_truth = true_graph.edge_feature
258-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in repeat")
259-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in repeat")
260-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in repeat")
211+
self.assert_equal(repeat_graph, true_graph, "repeat")
261212

262213
def test_repeat_interleave(self):
263214
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node,
@@ -275,15 +226,20 @@ def test_repeat_interleave(self):
275226
for i, graph in zip(repeats, graphs):
276227
true_graphs += [graph] * i
277228
true_graph = data.Graph.pack(true_graphs)
278-
adj_result = repeat_graph.adjacency.to_dense()
279-
adj_truth = true_graph.adjacency.to_dense()
280-
node_feat_result = repeat_graph.node_feature
281-
node_feat_truth = true_graph.node_feature
282-
edge_feat_result = repeat_graph.edge_feature
283-
edge_feat_truth = true_graph.edge_feature
284-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect index in repeat_interleave")
285-
self.assertTrue(torch.equal(node_feat_result, node_feat_truth), "Incorrect feature in repeat_interleave")
286-
self.assertTrue(torch.equal(edge_feat_result, edge_feat_truth), "Incorrect feature in repeat_interleave")
229+
self.assert_equal(repeat_graph, true_graph, "repeat interleave")
230+
231+
def test_repeated_index(self):
    """Index a packed graph with a repeated, unsorted, incomplete index list.

    The indexed batch must equal the batch obtained by packing the
    member graphs selected in the same order.
    """
    base_graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
    # four subgraphs, each keeping the nodes from ``start`` onwards
    graphs = [base_graph.subgraph(torch.arange(start, self.num_node)) for start in range(4)]
    packed_graph = data.Graph.pack(graphs)

    # special case: graph 3 is never selected, graphs 0 and 1 appear twice,
    # and the order is not sorted
    selection = [1, 0, 2, 1, 0]
    indexed_graph = packed_graph[selection]
    expected_graph = data.Graph.pack([graphs[i] for i in selection])
    self.assert_equal(indexed_graph, expected_graph, "repeated index")
287243

288244
def test_split(self):
289245
graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)
@@ -352,11 +308,29 @@ def test_directed(self):
352308
graph = digraph.undirected()
353309
adj_result = graph.adjacency.to_dense()
354310
adj_truth = (digraph.adjacency + digraph.adjacency.t()).to_dense()
355-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect undirected graph")
311+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect conversion from directed to undirected")
356312
digraph2 = graph.directed()
357313
adj_result = digraph2.adjacency.to_dense()
358314
adj_truth = adj_truth.triu()
359-
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect directed graph")
315+
self.assertTrue(torch.equal(adj_result, adj_truth), "Incorrect conversion from undirected to directed")
316+
317+
def test_match(self):
    """Check ``Graph.match`` against a brute-force comparison of edges.

    Queries are built from sampled existing edges, with one of columns
    0/1 per row overwritten by -1, plus fully random rows. The reference
    answer treats -1 as matching any value.
    """
    graph = data.Graph(self.edge_list, self.edge_weight, self.num_node)

    # sample existing edges and blank out one of columns 0/1 in each row
    picked = torch.randperm(graph.num_edge)[:self.num_node]
    pattern = graph.edge_list[picked]
    blank_column = torch.randint(2, (len(pattern), 1))
    pattern.scatter_(1, blank_column, -1)
    # append random rows with values drawn from [0, num_node)
    noise = torch.randint_like(pattern, self.num_node)
    query = torch.cat([pattern, noise])

    index_result, num_match_result = graph.match(query)
    # assumes the first result is a flat concatenation grouped per query
    result_groups = index_result.split(num_match_result.tolist())

    # brute force: compare every query against every edge of the graph
    expanded_query = query.unsqueeze(1)
    same_value = graph.edge_list.unsqueeze(0) == expanded_query
    is_wildcard = expanded_query == -1
    hit = (same_value | is_wildcard).all(dim=-1)
    query_index, index_truth = hit.nonzero().t()
    num_match_truth = torch.bincount(query_index, minlength=len(query))
    truth_groups = index_truth.split(num_match_truth.tolist())

    self.assertTrue(torch.equal(num_match_result, num_match_truth), "Incorrect edge match")
    # the order of matches within a query is not checked, so sort both sides
    for found, expected in zip(result_groups, truth_groups):
        self.assertTrue(torch.equal(found.sort()[0], expected.sort()[0]), "Incorrect edge match")
360334

361335

362336
if __name__ == "__main__":

0 commit comments

Comments
 (0)