Commit e9c02e1 (1 parent: 98c855a)

simplify KnowledgeGraphCompletion implementation
File tree

2 files changed: +48 −70 lines changed

torchdrug/tasks/reasoning.py
torchdrug/tasks/retrosynthesis.py

torchdrug/tasks/reasoning.py

Lines changed: 45 additions & 67 deletions
@@ -4,6 +4,7 @@
 from torch.utils import data as torch_data
 
 from torchdrug import core, tasks
+from torchdrug.layers import functional
 from torchdrug.core import Registry as R
 
 

@@ -133,13 +134,13 @@ def predict(self, batch, all_loss=None, metric=None):
             for neg_index in all_index.split(self.num_negative):
                 r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index))
                 h_index, t_index = torch.meshgrid(pos_h_index, neg_index)
-                t_pred = self.model(self.fact_graph, h_index, t_index, r_index)
+                t_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                 t_preds.append(t_pred)
             t_pred = torch.cat(t_preds, dim=-1)
             for neg_index in all_index.split(self.num_negative):
                 r_index = pos_r_index.unsqueeze(-1).expand(-1, len(neg_index))
                 t_index, h_index = torch.meshgrid(pos_t_index, neg_index)
-                h_pred = self.model(self.fact_graph, h_index, t_index, r_index)
+                h_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                 h_preds.append(h_pred)
             h_pred = torch.cat(h_preds, dim=-1)
             pred = torch.stack([t_pred, h_pred], dim=1)
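
Note on this hunk: the only behavioral change is that `all_loss` and `metric` are now forwarded to the model call, so scoring models that accumulate auxiliary losses or log metrics also receive them at evaluation time. For context, the surrounding loop scores every positive head against a chunk of candidate tails; a minimal shape sketch of that enumeration (toy values, not torchdrug code):

import torch

# Each positive head is paired with every candidate entity in the chunk;
# the relation index is broadcast across the chunk. Toy sizes for illustration.
pos_h_index = torch.tensor([3, 7])           # (batch_size,)
neg_index = torch.arange(5)                  # one chunk of candidate tails
h_index, t_index = torch.meshgrid(pos_h_index, neg_index)
# h_index[i][j] == pos_h_index[i], t_index[i][j] == neg_index[j]
print(h_index.shape, t_index.shape)          # torch.Size([2, 5]) twice
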
@@ -162,32 +163,26 @@ def predict(self, batch, all_loss=None, metric=None):
 
     def target(self, batch):
         # test target
+        batch_size = len(batch)
         pos_h_index, pos_t_index, pos_r_index = batch.t()
-        target = torch.stack([pos_t_index, pos_h_index], dim=1)
+        any = -torch.ones_like(pos_h_index)
+
+        pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1)
+        edge_index, num_t_truth = self.graph.match(pattern)
+        t_truth_index = self.graph.edge_list[edge_index, 1]
+        pos_index = functional._size_to_index(num_t_truth)
+        t_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
+        t_mask[pos_index, t_truth_index] = 0
+
+        pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1)
+        edge_index, num_h_truth = self.graph.match(pattern)
+        h_truth_index = self.graph.edge_list[edge_index, 0]
+        pos_index = functional._size_to_index(num_h_truth)
+        h_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
+        h_mask[pos_index, h_truth_index] = 0
 
-        hr_index_set, hr_inverse = torch.unique(pos_h_index * self.num_relation + pos_r_index, return_inverse=True)
-        tr_index_set, tr_inverse = torch.unique(pos_t_index * self.num_relation + pos_r_index, return_inverse=True)
-        hr_index2id = -torch.ones(self.num_entity * self.num_relation, dtype=torch.long, device=self.device)
-        hr_index2id[hr_index_set] = torch.arange(len(hr_index_set), device=self.device)
-        tr_index2id = -torch.ones(self.num_entity * self.num_relation, dtype=torch.long, device=self.device)
-        tr_index2id[tr_index_set] = torch.arange(len(tr_index_set), device=self.device)
-
-        h_index, t_index, r_index = self.graph.edge_list.t()
-        hr_index = h_index * self.num_relation + r_index
-        tr_index = t_index * self.num_relation + r_index
-        valid = hr_index2id[hr_index] >= 0
-        hr_index = hr_index[valid]
-        t_index = t_index[valid]
-        t_mask_set = torch.ones(len(hr_index_set), self.num_entity, dtype=torch.bool, device=self.device)
-        t_mask_set[hr_index2id[hr_index], t_index] = 0
-        t_mask = t_mask_set[hr_inverse]
-        valid = tr_index2id[tr_index] >= 0
-        tr_index = tr_index[valid]
-        h_index = h_index[valid]
-        h_mask_set = torch.ones(len(tr_index_set), self.num_entity, dtype=torch.bool, device=self.device)
-        h_mask_set[tr_index2id[tr_index], h_index] = 0
-        h_mask = h_mask_set[tr_inverse]
         mask = torch.stack([t_mask, h_mask], dim=1)
+        target = torch.stack([pos_t_index, pos_h_index], dim=1)
 
         # in case of GPU OOM
         return mask.cpu(), target.cpu()
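
What the simplification does: the deleted block built its own (entity, relation) → row bookkeeping with `torch.unique` and scatter tables in order to mask out known true triplets. The new code delegates this to `self.graph.match`, where a `-1` entry in an `(h, t, r)` pattern acts as a wildcard; it returns the indices of matching edges along with the number of matches per pattern. `functional._size_to_index` then expands those per-pattern counts into row indices so all true tails (or heads) can be masked in a single scatter. A self-contained sketch of that masking idiom, using `torch.repeat_interleave` as a stand-in for `_size_to_index` (an assumption about its behavior as used here, not the torchdrug implementation):

import torch

def size_to_index(size):
    # Stand-in for functional._size_to_index (assumed behavior):
    # [2, 1] -> [0, 0, 1], i.e. row i repeated size[i] times.
    return torch.repeat_interleave(torch.arange(len(size)), size)

num_entity = 4
num_t_truth = torch.tensor([2, 1])        # matches per pattern row
t_truth_index = torch.tensor([0, 3, 1])   # matched tail entities, concatenated
t_mask = torch.ones(2, num_entity, dtype=torch.bool)
t_mask[size_to_index(num_t_truth), t_truth_index] = 0
print(t_mask)
# tensor([[False,  True,  True, False],
#         [ True, False,  True,  True]])
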
@@ -225,47 +220,30 @@ def visualize(self, batch):
     @torch.no_grad()
     def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):
         batch_size = len(pos_h_index)
-
-        hr_index_set, hr_inverse = torch.unique(pos_h_index * self.num_relation + pos_r_index, return_inverse=True)
-        tr_index_set, tr_inverse = torch.unique(pos_t_index * self.num_relation + pos_r_index, return_inverse=True)
-        hr_index2id = -torch.ones(self.num_entity * self.num_relation, dtype=torch.long, device=self.device)
-        hr_index2id[hr_index_set] = torch.arange(len(hr_index_set), device=self.device)
-        tr_index2id = -torch.ones(self.num_entity * self.num_relation, dtype=torch.long, device=self.device)
-        tr_index2id[tr_index_set] = torch.arange(len(tr_index_set), device=self.device)
-
-        h_index, t_index, r_index = self.fact_graph.edge_list.t()
-        hr_index = h_index * self.num_relation + r_index
-        tr_index = t_index * self.num_relation + r_index
-        valid = hr_index2id[hr_index] >= 0
-        hr_index = hr_index[valid]
-        t_index = t_index[valid]
-        t_mask_set = torch.ones(len(hr_index_set), self.num_entity, dtype=torch.bool, device=self.device)
-        t_mask_set[hr_index2id[hr_index], t_index] = 0
-        t_mask = t_mask_set[hr_inverse]
-        valid = tr_index2id[tr_index] >= 0
-        tr_index = tr_index[valid]
-        h_index = h_index[valid]
-        h_mask_set = torch.ones(len(tr_index_set), self.num_entity, dtype=torch.bool, device=self.device)
-        h_mask_set[tr_index2id[tr_index], h_index] = 0
-        h_mask = h_mask_set[tr_inverse]
-
-        num_neg_t = t_mask.sum(dim=-1, keepdim=True)
-        num_neg_h = h_mask.sum(dim=-1, keepdim=True)
-        num_cum_neg_t = num_neg_t.cumsum(0)
-        num_cum_neg_h = num_neg_h.cumsum(0)
-
-        neg_t_index = t_mask.nonzero()[:, 1]
-        neg_h_index = h_mask.nonzero()[:, 1]
-
-        rand = torch.rand(batch_size, self.num_negative, device=self.device)
-        index = (rand[:batch_size // 2] * num_neg_t[:batch_size // 2]).long()
-        index = index + (num_cum_neg_t[:batch_size // 2] - num_neg_t[:batch_size // 2])
-        neg_index_t = neg_t_index[index]
-
-        index = (rand[batch_size // 2:] * num_neg_h[batch_size // 2:]).long()
-        index = index + (num_cum_neg_h[batch_size // 2:] - num_neg_h[batch_size // 2:])
-        neg_index_h = neg_h_index[index]
-
-        neg_index = torch.cat([neg_index_t, neg_index_h])
+        any = -torch.ones_like(pos_h_index)
+
+        pattern = torch.stack([pos_h_index, any, pos_r_index], dim=-1)
+        pattern = pattern[:batch_size // 2]
+        edge_index, num_t_truth = self.fact_graph.match(pattern)
+        t_truth_index = self.fact_graph.edge_list[edge_index, 1]
+        pos_index = functional._size_to_index(num_t_truth)
+        t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
+        t_mask[pos_index, t_truth_index] = 0
+        neg_t_candidate = t_mask.nonzero()[:, 1]
+        num_t_candidate = t_mask.sum(dim=-1)
+        neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative)
+
+        pattern = torch.stack([any, pos_t_index, pos_r_index], dim=-1)
+        pattern = pattern[batch_size // 2:]
+        edge_index, num_h_truth = self.fact_graph.match(pattern)
+        h_truth_index = self.fact_graph.edge_list[edge_index, 0]
+        pos_index = functional._size_to_index(num_h_truth)
+        h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
+        h_mask[pos_index, h_truth_index] = 0
+        neg_h_candidate = h_mask.nonzero()[:, 1]
+        num_h_candidate = h_mask.sum(dim=-1)
+        neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative)
+
+        neg_index = torch.cat([neg_t_index, neg_h_index])
 
         return neg_index
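
The old negative sampling drew random offsets into a flattened `nonzero()` candidate list by hand, using cumulative sums to locate each row's segment. The new code keeps the same candidate construction but hands the drawing to `functional.variadic_sample`. A reference sketch of the contract assumed here (`num_sample` draws, with replacement, from each variadic set of a concatenated candidate tensor); this is an illustration, not the torchdrug implementation:

import torch

def variadic_sample(candidate, size, num_sample):
    # Each set occupies a contiguous slice of `candidate`;
    # offset[i] is where set i starts.
    offset = size.cumsum(0) - size
    rand = torch.rand(len(size), num_sample)
    # Scale uniform draws into each set's range, then shift by its offset.
    index = (rand * size.unsqueeze(-1)).long() + offset.unsqueeze(-1)
    return candidate[index]

candidate = torch.tensor([5, 9, 2, 7])   # two sets: [5, 9] and [2, 7]
size = torch.tensor([2, 2])
print(variadic_sample(candidate, size, 3).shape)  # torch.Size([2, 3])

Note the split in the hunk above: the first half of the batch gets corrupted tails and the second half corrupted heads, matching the `pattern[:batch_size // 2]` / `pattern[batch_size // 2:]` slices.
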

torchdrug/tasks/retrosynthesis.py

Lines changed: 3 additions & 3 deletions
@@ -948,7 +948,7 @@ def predict_reactant(self, batch, num_beam=10, max_prediction=20, max_step=20):
             smiles_set.add(smiles)
         is_duplicate = torch.tensor(is_duplicate, device=self.device)
         result = result[~is_duplicate]
-        num_prediction = torch.bincount(result.synthon_id)
+        num_prediction = result.synthon_id.bincount(minlength=len(synthon))
 
         # remove extra predictions
         topk = functional.variadic_topk(result.logp, num_prediction, max_prediction)[1]
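
This change is more a fix than a simplification: `bincount()` without `minlength` only counts up to the largest id actually present, so if duplicate removal leaves the last synthons with zero predictions, `num_prediction` comes out shorter than `len(synthon)` and the following `variadic_topk` would group predictions against the wrong synthons. Pinning `minlength=len(synthon)` keeps one count per synthon, including zeros:

import torch

# Toy illustration: synthon 2 lost all its predictions to deduplication.
synthon_id = torch.tensor([0, 0, 1])
print(synthon_id.bincount())              # tensor([2, 1])    -- length 2, synthon 2 missing
print(synthon_id.bincount(minlength=3))   # tensor([2, 1, 0]) -- one count per synthon
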
@@ -1139,7 +1139,7 @@ def predict(self, batch, all_loss=None, metric=None):
         order = product_ids.argsort()
         logps = logps[order]
         reactant_ids = reactant_ids[order]
-        num_prediction = torch.bincount(product_ids)
+        num_prediction = product_ids.bincount()
         logps, topk = functional.variadic_topk(logps, num_prediction, self.max_prediction)
         topk_index = topk + (num_prediction.cumsum(0) - num_prediction).unsqueeze(-1)
         topk_index_shifted = torch.cat([-torch.ones(len(topk_index), 1, dtype=torch.long, device=self.device),
@@ -1170,7 +1170,7 @@ def predict(self, batch, all_loss=None, metric=None):
             setattr(reactant, k, v)
         reactant.logps = logps
 
-        num_prediction = torch.bincount(reactant.product_id)
+        num_prediction = reactant.product_id.bincount()
 
         return reactant, num_prediction  # (num_graph * k)
