44from torch .utils import data as torch_data
55
66from torchdrug import core , tasks
7+ from torchdrug .layers import functional
78from torchdrug .core import Registry as R
89
910
@@ -133,13 +134,13 @@ def predict(self, batch, all_loss=None, metric=None):
133134 for neg_index in all_index .split (self .num_negative ):
134135 r_index = pos_r_index .unsqueeze (- 1 ).expand (- 1 , len (neg_index ))
135136 h_index , t_index = torch .meshgrid (pos_h_index , neg_index )
136- t_pred = self .model (self .fact_graph , h_index , t_index , r_index )
137+ t_pred = self .model (self .fact_graph , h_index , t_index , r_index , all_loss = all_loss , metric = metric )
137138 t_preds .append (t_pred )
138139 t_pred = torch .cat (t_preds , dim = - 1 )
139140 for neg_index in all_index .split (self .num_negative ):
140141 r_index = pos_r_index .unsqueeze (- 1 ).expand (- 1 , len (neg_index ))
141142 t_index , h_index = torch .meshgrid (pos_t_index , neg_index )
142- h_pred = self .model (self .fact_graph , h_index , t_index , r_index )
143+ h_pred = self .model (self .fact_graph , h_index , t_index , r_index , all_loss = all_loss , metric = metric )
143144 h_preds .append (h_pred )
144145 h_pred = torch .cat (h_preds , dim = - 1 )
145146 pred = torch .stack ([t_pred , h_pred ], dim = 1 )
@@ -162,32 +163,26 @@ def predict(self, batch, all_loss=None, metric=None):
162163
def target(self, batch):
    """Build the filtered-ranking mask and the ground-truth indices for a batch.

    This is the post-change implementation of the method (the removed
    ``torch.unique``-based lookup tables are dropped in favour of
    ``self.graph.match``).

    Parameters:
        batch: 2-D tensor whose transpose unpacks into
            ``(pos_h_index, pos_t_index, pos_r_index)``, i.e. one
            (head, tail, relation) triplet per row.

    Returns:
        tuple ``(mask, target)``, both moved to CPU:

        - ``mask``: bool tensor of shape ``(batch_size, 2, num_entity)``.
          ``mask[:, 0]`` is the tail mask and ``mask[:, 1]`` the head mask.
          An entry is ``False`` where the corresponding entity forms an edge
          found in ``self.graph`` for that row's (head, relation) or
          (tail, relation) query, and ``True`` elsewhere.
        - ``target``: long tensor of shape ``(batch_size, 2)`` holding
          ``[pos_t_index, pos_h_index]`` per row, in the same tail/head
          order as ``mask``.
    """
    # test target
    batch_size = len(batch)
    pos_h_index, pos_t_index, pos_r_index = batch.t()
    # A column filled with -1 is used as the "match anything" slot of a
    # pattern. NOTE(review): this relies on Graph.match treating -1 as a
    # wildcard -- confirm against the torchdrug version in use.
    # (Named `wildcard` rather than `any` to avoid shadowing the builtin.)
    wildcard = -torch.ones_like(pos_h_index)

    # Tail side: find every known edge (h, ?, r) for each row of the batch.
    pattern = torch.stack([pos_h_index, wildcard, pos_r_index], dim=-1)
    # Presumably: flat indices of matching edges plus the number of matches
    # per pattern row -- verify against Graph.match.
    edge_index, num_t_truth = self.graph.match(pattern)
    t_truth_index = self.graph.edge_list[edge_index, 1]
    # Expand the per-row match counts into a row index for each matched edge.
    pos_index = functional._size_to_index(num_t_truth)
    t_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
    t_mask[pos_index, t_truth_index] = 0

    # Head side: same procedure for the pattern (?, t, r).
    pattern = torch.stack([wildcard, pos_t_index, pos_r_index], dim=-1)
    edge_index, num_h_truth = self.graph.match(pattern)
    h_truth_index = self.graph.edge_list[edge_index, 0]
    pos_index = functional._size_to_index(num_h_truth)
    h_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
    h_mask[pos_index, h_truth_index] = 0

    mask = torch.stack([t_mask, h_mask], dim=1)
    target = torch.stack([pos_t_index, pos_h_index], dim=1)

    # in case of GPU OOM
    return mask.cpu(), target.cpu()
@@ -225,47 +220,30 @@ def visualize(self, batch):
@torch.no_grad()
def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):
    """Sample negative entities that do not form a known edge in the fact graph.

    This is the post-change implementation of the method (the removed
    ``torch.unique`` / cumulative-sum sampler is replaced by
    ``self.fact_graph.match`` and ``functional.variadic_sample``).

    The batch is split in two: for the first ``batch_size // 2`` rows the
    tail is corrupted (candidates exclude tails of known ``(h, ?, r)``
    edges); for the remaining rows the head is corrupted (candidates exclude
    heads of known ``(?, t, r)`` edges).

    Parameters:
        pos_h_index: 1-D tensor of head entity indices of the positive triplets.
        pos_t_index: 1-D tensor of tail entity indices of the positive triplets.
        pos_r_index: 1-D tensor of relation indices of the positive triplets.

    Returns:
        Tensor of sampled negative entity indices: the tail negatives for the
        first half of the batch concatenated with the head negatives for the
        second half. Presumably of shape ``(batch_size, num_negative)`` --
        this depends on the output layout of ``functional.variadic_sample``;
        confirm.
    """
    batch_size = len(pos_h_index)
    # A column filled with -1 is used as the "match anything" slot of a
    # pattern. NOTE(review): this relies on Graph.match treating -1 as a
    # wildcard -- confirm against the torchdrug version in use.
    # (Named `wildcard` rather than `any` to avoid shadowing the builtin.)
    wildcard = -torch.ones_like(pos_h_index)

    # First half of the batch: corrupt the tail.
    pattern = torch.stack([pos_h_index, wildcard, pos_r_index], dim=-1)
    pattern = pattern[:batch_size // 2]
    # Presumably: flat indices of matching edges plus the number of matches
    # per pattern row -- verify against Graph.match.
    edge_index, num_t_truth = self.fact_graph.match(pattern)
    t_truth_index = self.fact_graph.edge_list[edge_index, 1]
    # Expand the per-row match counts into a row index for each matched edge.
    pos_index = functional._size_to_index(num_t_truth)
    t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
    t_mask[pos_index, t_truth_index] = 0
    # Flatten the surviving (row, entity) pairs into one candidate list with
    # a per-row count, then draw `num_negative` candidates from each row's
    # segment.
    neg_t_candidate = t_mask.nonzero()[:, 1]
    num_t_candidate = t_mask.sum(dim=-1)
    neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative)

    # Second half of the batch: corrupt the head.
    pattern = torch.stack([wildcard, pos_t_index, pos_r_index], dim=-1)
    pattern = pattern[batch_size // 2:]
    edge_index, num_h_truth = self.fact_graph.match(pattern)
    h_truth_index = self.fact_graph.edge_list[edge_index, 0]
    pos_index = functional._size_to_index(num_h_truth)
    h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
    h_mask[pos_index, h_truth_index] = 0
    neg_h_candidate = h_mask.nonzero()[:, 1]
    num_h_candidate = h_mask.sum(dim=-1)
    neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative)

    neg_index = torch.cat([neg_t_index, neg_h_index])

    return neg_index
0 commit comments