@@ -36,6 +36,16 @@ def block_diag(self, tensors):
3636 y += num_col
3737 return result
3838
39+ def assert_equal (self , graph1 , graph2 , prompt ):
40+ self .assertTrue (torch .equal (graph1 .adjacency .to_dense (), graph2 .adjacency .to_dense ()),
41+ "Incorrect edge list in %s" % prompt )
42+ if hasattr (graph1 , "node_feature" ) and hasattr (graph2 , "node_feature" ):
43+ self .assertTrue (torch .equal (graph1 .node_feature , graph2 .node_feature ), "Incorrect feature in %s" % prompt )
44+ if hasattr (graph1 , "edge_feature" ) and hasattr (graph2 , "edge_feature" ):
45+ self .assertTrue (torch .equal (graph1 .edge_feature , graph2 .edge_feature ), "Incorrect feature in %s" % prompt )
46+ if hasattr (graph1 , "graph_feature" ) and hasattr (graph2 , "graph_feature" ):
47+ self .assertTrue (torch .equal (graph1 .graph_feature , graph2 .graph_feature ), "Incorrect feature in %s" % prompt )
48+
3949 def test_type_cast (self ):
4050 dense_edge_feature = torch .zeros (self .num_node , self .num_node , self .num_feature )
4151 dense_edge_feature [tuple (self .edge_list .t ())] = self .edge_feature
@@ -44,12 +54,8 @@ def test_type_cast(self):
4454 node_feature = self .node_feature .tolist (), edge_feature = self .edge_feature .tolist ())
4555 graph2 = data .Graph (self .edge_list .numpy (), self .edge_weight .numpy (), self .num_node ,
4656 node_feature = self .node_feature .numpy (), edge_feature = self .edge_feature .numpy ())
47- self .assertTrue (torch .equal (graph .edge_list , graph1 .edge_list ), "Incorrect type cast" )
48- self .assertTrue (torch .equal (graph .edge_feature , graph1 .edge_feature ), "Incorrect type cast" )
49- self .assertTrue (torch .equal (graph1 .edge_list , graph2 .edge_list ), "Incorrect type cast" )
50- self .assertTrue (torch .equal (graph1 .edge_weight , graph2 .edge_weight ), "Incorrect type cast" )
51- self .assertTrue (torch .equal (graph1 .node_feature , graph2 .node_feature ), "Incorrect type cast" )
52- self .assertTrue (torch .equal (graph1 .edge_feature , graph2 .edge_feature ), "Incorrect type cast" )
57+ self .assert_equal (graph , graph1 , "type cast" )
58+ self .assert_equal (graph , graph2 , "type cast" )
5359
5460 def test_index (self ):
5561 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
@@ -58,7 +64,7 @@ def test_index(self):
5864 index = tuple (torch .randint (self .num_node , (2 ,)).tolist ())
5965 result = graph [index ]
6066 truth = self .adjacency [index ]
61- self .assertTrue (torch .equal (result , truth ), "Incorrect index in single item" )
67+ self .assertTrue (torch .equal (result , truth ), "Incorrect edge in single item" )
6268
6369 h_index = torch .randperm (self .num_node )[:self .num_node // 2 ]
6470 t_index = torch .randperm (self .num_node )[:self .num_node // 2 ]
@@ -71,7 +77,7 @@ def test_index(self):
7177 adj_truth [not_h_index , :] = 0
7278 adj_truth [:, not_t_index ] = 0
7379 feat_truth = self .node_feature
74- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in node mask" )
80+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect edge list in node mask" )
7581 self .assertTrue (torch .equal (feat_result , feat_truth ), "Incorrect feature in node mask" )
7682
7783 new_graph = graph [:, 1 : - 1 ]
@@ -80,7 +86,7 @@ def test_index(self):
8086 adj_truth = torch .zeros_like (self .adjacency )
8187 adj_truth [:, 1 : - 1 ] = self .adjacency [:, 1 : - 1 ]
8288 feat_truth = self .node_feature
83- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in slice" )
89+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect edge list in slice" )
8490 self .assertTrue (torch .equal (feat_result , feat_truth ), "Incorrect feature in slice" )
8591
8692 index = torch .randperm (self .num_node )[:self .num_node // 2 ]
@@ -89,7 +95,7 @@ def test_index(self):
8995 feat_result = new_graph .node_feature
9096 adj_truth = self .adjacency [index ][:, index ]
9197 feat_truth = self .node_feature [index ]
92- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in subgraph" )
98+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect edge list in subgraph" )
9399 self .assertTrue (torch .equal (feat_result , feat_truth ), "Incorrect feature in subgraph" )
94100
95101 def test_device (self ):
@@ -100,11 +106,7 @@ def test_device(self):
100106 self .assertEqual (graph1 .adjacency .device .type , "cuda" , "Incorrect device" )
101107 graph2 = graph1 .cpu ()
102108 self .assertEqual (graph2 .adjacency .device .type , "cpu" , "Incorrect device" )
103- self .assertTrue (torch .equal (graph .adjacency .to_dense (), graph2 .adjacency .to_dense ()),
104- "Incorrect feature when changing device" )
105- self .assertTrue (torch .equal (graph .node_feature , graph2 .node_feature ), "Incorrect feature when changing device" )
106- self .assertTrue (torch .equal (graph .edge_feature , graph2 .edge_feature ), "Incorrect feature when changing device" )
107- self .assertTrue (torch .equal (graph .graph_feature , graph2 .graph_feature ), "Incorrect feature when changing device" )
109+ self .assert_equal (graph , graph2 , "device" )
108110
109111 def test_pack (self ):
110112 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
@@ -128,26 +130,15 @@ def test_pack(self):
128130 edge_feat_truth = torch .cat ([graph .edge_feature for graph in graphs ])
129131 graph_feat_result = packed_graph .graph_feature
130132 graph_feat_truth = torch .stack ([graph .graph_feature for graph in graphs ])
131- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in pack" )
133+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect edge list in pack" )
132134 self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in pack" )
133135 self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in pack" )
134136 self .assertTrue (torch .equal (graph_feat_result , graph_feat_truth ), "Incorrect feature in pack" )
135137
136138 new_graphs = packed_graph .unpack ()
137139 self .assertEqual (len (graphs ), len (new_graphs ), "Incorrect length in unpack" )
138140 for graph , new_graph in zip (graphs , new_graphs ):
139- adj_truth = graph .adjacency .to_dense ()
140- adj_result = new_graph .adjacency .to_dense ()
141- node_feat_truth = graph .node_feature
142- node_feat_result = new_graph .node_feature
143- edge_feat_truth = graph .edge_feature
144- edge_feat_result = new_graph .edge_feature
145- graph_feat_truth = graph .graph_feature
146- graph_feat_result = new_graph .graph_feature
147- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in unpack" )
148- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in unpack" )
149- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in unpack" )
150- self .assertTrue (torch .equal (graph_feat_result , graph_feat_truth ), "Incorrect feature in unpack" )
141+ self .assert_equal (graph , new_graph , "unpack" )
151142
152143 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
153144 node_feature = self .node_feature , edge_feature = self .edge_feature )
@@ -158,28 +149,12 @@ def test_pack(self):
158149 for start in range (4 ):
159150 mask [start * self .num_node + start : (start + 1 ) * self .num_node ] = 1
160151 packed_graph2 = packed_graph2 .subgraph (mask )
161- adj_result = packed_graph2 .adjacency .to_dense ()
162- adj_truth = packed_graph .adjacency .to_dense ()
163- node_feat_result = packed_graph2 .node_feature
164- node_feat_truth = packed_graph .node_feature
165- edge_feat_result = packed_graph2 .edge_feature
166- edge_feat_truth = packed_graph .edge_feature
167- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in subgraph" )
168- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in subgraph" )
169- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in subgraph" )
152+ self .assert_equal (packed_graph , packed_graph2 , "subgraph" )
170153
171154 packed_graph = data .Graph .pack (graphs [::2 ])
172155 packed_graph2 = data .Graph .pack (graphs )[::2 ]
173- adj_result = packed_graph2 .adjacency .to_dense ()
174- adj_truth = packed_graph .adjacency .to_dense ()
175- node_feat_result = packed_graph2 .node_feature
176- node_feat_truth = packed_graph .node_feature
177- edge_feat_result = packed_graph2 .edge_feature
178- edge_feat_truth = packed_graph .edge_feature
179156 self .assertEqual (len (packed_graph ), len (packed_graph2 ), "Incorrect batch size in graph mask" )
180- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in graph mask" )
181- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in graph mask" )
182- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in graph mask" )
157+ self .assert_equal (packed_graph , packed_graph2 , "graph mask" )
183158
184159 def test_reorder (self ):
185160 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
@@ -192,9 +167,9 @@ def test_reorder(self):
192167 node_feat_truth = graph .node_feature [order ]
193168 edge_feat_result = new_graph .edge_feature
194169 edge_feat_truth = graph .edge_feature
195- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect node reorder" )
196- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect node reorder" )
197- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect node reorder" )
170+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect edge list in node reorder" )
171+ self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in node reorder" )
172+ self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in node reorder" )
198173
199174 order = torch .randperm (graph .num_edge )
200175 new_graph = graph .edge_mask (order )
@@ -204,9 +179,9 @@ def test_reorder(self):
204179 node_feat_truth = graph .node_feature
205180 edge_feat_result = new_graph .edge_feature
206181 edge_feat_truth = graph .edge_feature [order ]
207- self .assertTrue (torch .equal (edge_result , edge_truth ), "Incorrect edge reorder" )
208- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect edge reorder" )
209- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect edge reorder" )
182+ self .assertTrue (torch .equal (edge_result , edge_truth ), "Incorrect edge list in edge reorder" )
183+ self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in edge reorder" )
184+ self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in edge reorder" )
210185
211186 graphs = []
212187 for start in range (4 ):
@@ -216,30 +191,14 @@ def test_reorder(self):
216191 order = torch .randperm (4 )
217192 packed_graph = packed_graph .subbatch (order )
218193 packed_graph2 = data .Graph .pack ([graphs [i ] for i in order ])
219- adj_result = packed_graph .adjacency .to_dense ()
220- adj_truth = packed_graph2 .adjacency .to_dense ()
221- node_feat_result = packed_graph .node_feature
222- node_feat_truth = packed_graph2 .node_feature
223- edge_feat_result = packed_graph .edge_feature
224- edge_feat_truth = packed_graph2 .edge_feature
225- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect graph reorder" )
226- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect graph reorder" )
227- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect graph reorder" )
194+ self .assert_equal (packed_graph , packed_graph2 , "graph reorder" )
228195
229196 def test_repeat (self ):
230197 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
231198 node_feature = self .node_feature , edge_feature = self .edge_feature )
232199 repeat_graph = graph .repeat (5 )
233200 true_graph = data .Graph .pack ([graph ] * 5 )
234- adj_result = repeat_graph .adjacency .to_dense ()
235- adj_truth = true_graph .adjacency .to_dense ()
236- node_feat_result = repeat_graph .node_feature
237- node_feat_truth = true_graph .node_feature
238- edge_feat_result = repeat_graph .edge_feature
239- edge_feat_truth = true_graph .edge_feature
240- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in repeat" )
241- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in repeat" )
242- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in repeat" )
201+ self .assert_equal (repeat_graph , true_graph , "repeat" )
243202
244203 # special case: graphs with no edges
245204 graphs = [graph .edge_mask ([]), graph .edge_mask ([])]
@@ -249,15 +208,7 @@ def test_repeat(self):
249208 packed_graph = data .Graph .pack (graphs )
250209 repeat_graph = packed_graph .repeat (5 )
251210 true_graph = data .Graph .pack (graphs * 5 )
252- adj_result = repeat_graph .adjacency .to_dense ()
253- adj_truth = true_graph .adjacency .to_dense ()
254- node_feat_result = repeat_graph .node_feature
255- node_feat_truth = true_graph .node_feature
256- edge_feat_result = repeat_graph .edge_feature
257- edge_feat_truth = true_graph .edge_feature
258- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in repeat" )
259- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in repeat" )
260- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in repeat" )
211+ self .assert_equal (repeat_graph , true_graph , "repeat" )
261212
262213 def test_repeat_interleave (self ):
263214 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node ,
@@ -275,15 +226,20 @@ def test_repeat_interleave(self):
275226 for i , graph in zip (repeats , graphs ):
276227 true_graphs += [graph ] * i
277228 true_graph = data .Graph .pack (true_graphs )
278- adj_result = repeat_graph .adjacency .to_dense ()
279- adj_truth = true_graph .adjacency .to_dense ()
280- node_feat_result = repeat_graph .node_feature
281- node_feat_truth = true_graph .node_feature
282- edge_feat_result = repeat_graph .edge_feature
283- edge_feat_truth = true_graph .edge_feature
284- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect index in repeat_interleave" )
285- self .assertTrue (torch .equal (node_feat_result , node_feat_truth ), "Incorrect feature in repeat_interleave" )
286- self .assertTrue (torch .equal (edge_feat_result , edge_feat_truth ), "Incorrect feature in repeat_interleave" )
229+ self .assert_equal (repeat_graph , true_graph , "repeat interleave" )
230+
231+ def test_repeated_index (self ):
232+ graph = data .Graph (self .edge_list , self .edge_weight , self .num_node )
233+ graphs = []
234+ for start in range (4 ):
235+ index = torch .arange (start , self .num_node )
236+ graphs .append (graph .subgraph (index ))
237+ packed_graph = data .Graph .pack (graphs )
238+ # special case: some indexes missing, not sorted
239+ index = [1 , 0 , 2 , 1 , 0 ]
240+ packed_graph = packed_graph [index ]
241+ packed_graph2 = data .Graph .pack ([graphs [i ] for i in index ])
242+ self .assert_equal (packed_graph , packed_graph2 , "repeated index" )
287243
288244 def test_split (self ):
289245 graph = data .Graph (self .edge_list , self .edge_weight , self .num_node )
@@ -352,11 +308,29 @@ def test_directed(self):
352308 graph = digraph .undirected ()
353309 adj_result = graph .adjacency .to_dense ()
354310 adj_truth = (digraph .adjacency + digraph .adjacency .t ()).to_dense ()
355- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect undirected graph " )
311+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect conversion from directed to undirected " )
356312 digraph2 = graph .directed ()
357313 adj_result = digraph2 .adjacency .to_dense ()
358314 adj_truth = adj_truth .triu ()
359- self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect directed graph" )
315+ self .assertTrue (torch .equal (adj_result , adj_truth ), "Incorrect conversion from undirected to directed" )
316+
317+ def test_match (self ):
318+ graph = data .Graph (self .edge_list , self .edge_weight , self .num_node )
319+ index = torch .randperm (graph .num_edge )[:self .num_node ]
320+ edge = graph .edge_list [index ]
321+ mask = torch .randint (2 , (len (edge ), 1 ))
322+ edge .scatter_ (1 , mask , - 1 )
323+ random = torch .randint_like (edge , self .num_node )
324+ edge = torch .cat ([edge , random ])
325+ index_result , num_match_result = graph .match (edge )
326+ index_results = index_result .split (num_match_result .tolist ())
327+ match = ((graph .edge_list .unsqueeze (0 ) == edge .unsqueeze (1 )) | (edge .unsqueeze (1 ) == - 1 )).all (dim = - 1 )
328+ query_index , index_truth = match .nonzero ().t ()
329+ num_match_truth = torch .bincount (query_index , minlength = len (edge ))
330+ index_truths = index_truth .split (num_match_truth .tolist ())
331+ self .assertTrue (torch .equal (num_match_result , num_match_truth ), "Incorrect edge match" )
332+ for index_result , index_truth in zip (index_results , index_truths ):
333+ self .assertTrue (torch .equal (index_result .sort ()[0 ], index_truth .sort ()[0 ]), "Incorrect edge match" )
360334
361335
362336if __name__ == "__main__" :
0 commit comments